From 68134a60241cb3b778f2b27699f98eca87bd940a Mon Sep 17 00:00:00 2001
From: Thai Nguyen <thaink@google.com>
Date: Wed, 11 Nov 2020 18:58:32 -0800
Subject: [PATCH] Support string input in TFLite StridedSlice kernel

PiperOrigin-RevId: 341957475
Change-Id: I96c79ba6a95b09861fe90120f3b6431f3d8e3a53
---
 tensorflow/compiler/mlir/lite/ir/tfl_ops.td   |  4 +-
 .../compiler/mlir/lite/tests/legalize-tf.mlir |  7 +++
 tensorflow/compiler/mlir/lite/tests/ops.mlir  |  6 +++
 .../internal/reference/strided_slice.h        | 31 +++++++++--
 tensorflow/lite/kernels/register.cc           |  2 +-
 tensorflow/lite/kernels/strided_slice.cc      | 14 +++--
 tensorflow/lite/kernels/strided_slice_test.cc | 53 +++++++++++++++++++
 .../lite/testing/op_tests/strided_slice.py    | 14 +++++
 .../lite/tools/versioning/op_version.cc       |  3 ++
 .../lite/tools/versioning/runtime_version.cc  |  1 +
 10 files changed, 123 insertions(+), 12 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index ae2e424ec81..a4f67c5afe9 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -3405,7 +3405,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8]>:$input,
+    TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8, TFL_Str]>:$input,
     TFL_I32Tensor:$begin,
     TFL_I32Tensor:$end,
     TFL_I32Tensor:$strides,
@@ -3418,7 +3418,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [
   );
 
   let results = (outs
-    TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8]>:$output
+    TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8, TFL_Str]>:$output
   );
 
   let hasOptions = 1;
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
index 5e36f4af802..dd8bbdb8372 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
@@ -1122,6 +1122,13 @@ func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %arg1:
   // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 6 : i32, ellipsis_mask = 0 : i32, end_mask = 6 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10xf32>
 }
 
+func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> {
+  %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
+  return %0 : tensor<1x2x2x5x!tf.string>
+  // CHECK-LABEL: strided_slice_with_string
+  // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
+}
+
 func @slice1Tensor(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<?x3x5xf32> {
   %0 = "tf.Slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
   return %0 : tensor<?x3x5xf32>
diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir
index 3a98f6db0c4..a3aea7bd593 100644
--- a/tensorflow/compiler/mlir/lite/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir
@@ -1458,6 +1458,12 @@ func @testStridedSliceTFType(%arg0: tensor<12x2x2x5xui8>, %arg1: tensor<1xi32>,
   return %0 : tensor<1x2x2x5x!tf.quint8>
 }
 
+// CHECK-LABEL: testStridedSliceWithString
+func @testStridedSliceWithString(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> {
+  %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
+  return %0 : tensor<1x2x2x5x!tf.string>
+}
+
 // -----
 
 func @testStridedSliceWithInvalidOutputType(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xi32> {
diff --git a/tensorflow/lite/kernels/internal/reference/strided_slice.h b/tensorflow/lite/kernels/internal/reference/strided_slice.h
index 8b6f0c13da1..24aa798d9c9 100644
--- a/tensorflow/lite/kernels/internal/reference/strided_slice.h
+++ b/tensorflow/lite/kernels/internal/reference/strided_slice.h
@@ -17,18 +17,19 @@ limitations under the License.
 
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/internal/portable_tensor.h"
 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
 #include "tensorflow/lite/kernels/internal/types.h"
 
 namespace tflite {
 
 namespace reference_ops {
+
 template <typename T>
 inline void StridedSlice(const tflite::StridedSliceParams& op_params,
                          const RuntimeShape& unextended_input_shape,
-                         const T* input_data,
                          const RuntimeShape& unextended_output_shape,
-                         T* output_data) {
+                         SequentialTensorWriter<T>* writer) {
   using strided_slice::LoopCondition;
   using strided_slice::StartForAxis;
   using strided_slice::StopForAxis;
@@ -57,7 +58,6 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
   const int start_4 = StartForAxis(params_copy, input_shape, 4);
   const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4);
 
-  T* out_ptr = output_data;
   for (int offset_0 = start_0 * input_shape.Dims(1),
            end_0 = stop_0 * input_shape.Dims(1),
            step_0 = params_copy.strides[0] * input_shape.Dims(1);
@@ -81,13 +81,36 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
           for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4;
                !LoopCondition(offset_4, end_4, params_copy.strides[4]);
                offset_4 += params_copy.strides[4]) {
-            *out_ptr++ = input_data[offset_4];
+            writer->Write(offset_4);
           }
         }
       }
     }
   }
 }
+
+template <typename T>
+inline void StridedSlice(const tflite::StridedSliceParams& op_params,
+                         const RuntimeShape& unextended_input_shape,
+                         const T* input_data,
+                         const RuntimeShape& unextended_output_shape,
+                         T* output_data) {
+  SequentialTensorWriter<T> writer(input_data, output_data);
+  StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
+                  &writer);
+}
+
+template <typename T>
+inline void StridedSlice(const tflite::StridedSliceParams& op_params,
+                         const RuntimeShape& unextended_input_shape,
+                         const TfLiteTensor* input,
+                         const RuntimeShape& unextended_output_shape,
+                         TfLiteTensor* output) {
+  SequentialTensorWriter<T> writer(input, output);
+  StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
+                  &writer);
+}
+
 }  // namespace reference_ops
 }  // namespace tflite
 
diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc
index cd0c297a545..9aa14e579d4 100644
--- a/tensorflow/lite/kernels/register.cc
+++ b/tensorflow/lite/kernels/register.cc
@@ -157,7 +157,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
              /* max_version = */ 2);
   AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE(),
              /* min_version = */ 1,
-             /* max_version = */ 4);
+             /* max_version = */ 5);
   AddBuiltin(BuiltinOperator_EXP, Register_EXP());
   AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2(),
              /* min_version = */ 1,
diff --git a/tensorflow/lite/kernels/strided_slice.cc b/tensorflow/lite/kernels/strided_slice.cc
index d10e99c1997..3f2fd580a0b 100644
--- a/tensorflow/lite/kernels/strided_slice.cc
+++ b/tensorflow/lite/kernels/strided_slice.cc
@@ -190,11 +190,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   }
   StridedSliceParams op_params = BuildStridedSliceParams(&op_context);
 
-#define TF_LITE_STRIDED_SLICE(kernel_type, data_type)                    \
-  kernel_type::StridedSlice(op_params, GetTensorShape(op_context.input), \
-                            GetTensorData<data_type>(op_context.input),  \
-                            GetTensorShape(op_context.output),           \
-                            GetTensorData<data_type>(op_context.output))
+#define TF_LITE_STRIDED_SLICE(kernel_type, data_type)                \
+  kernel_type::StridedSlice<data_type>(                              \
+      op_params, GetTensorShape(op_context.input), op_context.input, \
+      GetTensorShape(op_context.output), op_context.output)
 
   switch (op_context.input->type) {
     case kTfLiteFloat32:
@@ -232,6 +231,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
         TF_LITE_STRIDED_SLICE(reference_ops, bool);
       }
       break;
+    case kTfLiteString:
+      if (kernel_type == kReference) {
+        TF_LITE_STRIDED_SLICE(reference_ops, string);
+      }
+      break;
     default:
       TF_LITE_KERNEL_LOG(context,
                          "Type %s is currently not supported "
diff --git a/tensorflow/lite/kernels/strided_slice_test.cc b/tensorflow/lite/kernels/strided_slice_test.cc
index d66cf884474..98521b889f9 100644
--- a/tensorflow/lite/kernels/strided_slice_test.cc
+++ b/tensorflow/lite/kernels/strided_slice_test.cc
@@ -55,6 +55,9 @@ class StridedSliceOpModel : public SingleOpModel {
   void SetInput(const std::vector<input_type> data) {
     PopulateTensor<input_type>(input_, data);
   }
+  void SetStringInput(std::initializer_list<string> data) {
+    PopulateStringTensor(input_, data);
+  }
   void SetBegin(std::initializer_list<int32_t> data) {
     PopulateTensor<int32_t>(begin_, data);
   }
@@ -68,6 +71,9 @@ class StridedSliceOpModel : public SingleOpModel {
   std::vector<input_type> GetOutput() {
     return ExtractVector<input_type>(output_);
   }
+  std::vector<string> GetStringOutput() {
+    return ExtractVector<string>(output_);
+  }
   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
 
  private:
@@ -692,5 +698,52 @@ TYPED_TEST(StridedSliceOpTest, In3D_Backward) {
   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0, 1, 2}));
 }
 
+TEST(StridedSliceOpTest, In1D_String_NegativeBegin) {
+  StridedSliceOpModel<std::string> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+  m.SetStringInput({"a", "b", "c", "d"});
+  m.SetBegin({-3});
+  m.SetEnd({3});
+  m.SetStrides({1});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+  EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"b", "c"}));
+}
+
+TEST(StridedSliceOpTest, In3D_String_BackwardSmallBegin) {
+  StridedSliceOpModel<std::string> m({1, 1, 2}, {1}, {1}, {1}, 0, 1, 0, 0, 0);
+  m.SetStringInput({"a", "b"});
+  m.SetBegin({1});
+  m.SetEnd({0});
+  m.SetStrides({1});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0, 1, 2}));
+}
+
+TEST(StridedSliceOpTest, In3D_String_SmallBeginWithhrinkAxis1) {
+  StridedSliceOpModel<std::string> m({2, 3, 2}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
+  m.SetStringInput(
+      {"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"});
+  m.SetBegin({0});
+  m.SetEnd({1});
+  m.SetStrides({1});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2}));
+  EXPECT_THAT(m.GetStringOutput(),
+              ElementsAreArray({"1", "2", "3", "4", "5", "6"}));
+}
+
+TEST(StridedSliceOpTest, In5D_String_IdentityShrinkAxis1) {
+  StridedSliceOpModel<std::string> m({2, 2, 2, 1, 2}, {5}, {5}, {5}, 0, 0, 0, 0,
+                                     1);
+  m.SetStringInput({"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11",
+                    "12", "13", "14", "15", "16"});
+  m.SetBegin({0, 0, 0, 0, 0});
+  m.SetEnd({2, 1, 2, 1, 2});
+  m.SetStrides({1, 1, 1, 1, 1});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 1, 2}));
+  EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"1", "2", "3", "4"}));
+}
+
 }  // namespace
 }  // namespace tflite
diff --git a/tensorflow/lite/testing/op_tests/strided_slice.py b/tensorflow/lite/testing/op_tests/strided_slice.py
index 3a04354c202..8668e139f34 100644
--- a/tensorflow/lite/testing/op_tests/strided_slice.py
+++ b/tensorflow/lite/testing/op_tests/strided_slice.py
@@ -230,6 +230,20 @@ def make_strided_slice_tests(options):
             "shrink_axis_mask": [0],
             "constant_indices": [True, False],
             "fully_quantize": [False],
+        },
+        # String input.
+        {
+            "dtype": [tf.string],
+            "index_type": [tf.int32],
+            "input_shape": [[12, 2, 2, 5]],
+            "begin": [[0, 0, 0, 0]],
+            "end": [[8, 2, 2, 3]],
+            "strides": [[2, 1, 3, 1]],
+            "begin_mask": [8],
+            "end_mask": [3],
+            "shrink_axis_mask": [None, -1],
+            "constant_indices": [True, False],
+            "fully_quantize": [False],
         }
     ]
   _make_strided_slice_tests(options, test_parameters, expected_tf_failures=2)
diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc
index 6b9ff9c1dcf..1f84c261cdb 100644
--- a/tensorflow/lite/tools/versioning/op_version.cc
+++ b/tensorflow/lite/tools/versioning/op_version.cc
@@ -387,6 +387,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
       return 1;
 
     case BuiltinOperator_STRIDED_SLICE:
+      if (op_sig.input_types.at(0) == TensorType_STRING) {
+        return 5;
+      }
       if (op_sig.options.single_input_op.num_dims > 4) {
         return 4;
       }
diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc
index 2e71882f469..fa0b01fc939 100644
--- a/tensorflow/lite/tools/versioning/runtime_version.cc
+++ b/tensorflow/lite/tools/versioning/runtime_version.cc
@@ -218,6 +218,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
               {{BuiltinOperator_STRIDED_SLICE, 2}, "1.14.0"},
               {{BuiltinOperator_STRIDED_SLICE, 3}, "2.1.0"},
               {{BuiltinOperator_STRIDED_SLICE, 4}, "2.2.0"},
+              {{BuiltinOperator_STRIDED_SLICE, 5}, kPendingReleaseVersion},
               {{BuiltinOperator_TOPK_V2, 1}, "1.7.0"},
               {{BuiltinOperator_TOPK_V2, 2}, "1.14.0"},
               {{BuiltinOperator_ARG_MAX, 1}, "1.9.0"},