From 0c4d4090c9b02d356344e25afa2542e6746846d5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 4 Apr 2019 17:55:17 -0700 Subject: [PATCH] Add string support to slice. PiperOrigin-RevId: 242045915 --- tensorflow/lite/kernels/internal/BUILD | 7 +++ .../internal/optimized/optimized_ops.h | 29 ++++++++++--- .../internal/reference/reference_ops.h | 27 ++++++++++-- tensorflow/lite/kernels/internal/tensor.h | 43 +++++++++++++++++++ tensorflow/lite/kernels/register.cc | 5 ++- tensorflow/lite/kernels/slice.cc | 43 ++++++++++--------- tensorflow/lite/kernels/slice_test.cc | 21 +++++++++ .../lite/testing/generate_examples_lib.py | 6 +-- tensorflow/lite/toco/tflite/operator.cc | 6 ++- tensorflow/lite/toco/tflite/operator_test.cc | 12 ++++++ 10 files changed, 161 insertions(+), 38 deletions(-) diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 840be3a1ddb..6a9339f3e58 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -189,6 +189,7 @@ cc_library( ":types", ":reference_base", ":round", + ":tensor", ":tensor_utils", "//third_party/eigen3", "@gemmlowp", @@ -222,6 +223,7 @@ cc_library( deps = [ ":quantization_util", ":strided_slice_logic", + ":tensor", ":tensor_utils", ":types", ":legacy_types", @@ -258,6 +260,7 @@ cc_library( ":tensor", ":types", "//tensorflow/core/kernels:eigen_spatial_convolutions-inl", + "//tensorflow/lite:string_util", "//tensorflow/lite/c:c_api_internal", "//third_party/eigen3", ], @@ -341,6 +344,7 @@ cc_library( ":quantization_util", ":round", ":strided_slice_logic", + ":tensor", ":types", "@gemmlowp", "//tensorflow/lite/c:c_api_internal", @@ -376,6 +380,7 @@ cc_library( ":round", ":strided_slice_logic", ":legacy_types", + ":tensor", ":types", "@gemmlowp", "//tensorflow/lite/c:c_api_internal", @@ -401,6 +406,7 @@ cc_library( ], deps = [ ":types", + "//tensorflow/lite:string_util", "//tensorflow/lite/c:c_api_internal", ], ) @@ -414,6 +420,7 @@ cc_library( ], deps = [ ":types", + "//tensorflow/lite:string_util", "//tensorflow/lite/c:c_api_internal", ], ) diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index 3d1845246a0..e0231785c9c 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -34,12 +34,14 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "fixedpoint/fixedpoint.h" #include "public/gemmlowp.h" +#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/round.h" #include "tensorflow/lite/kernels/internal/strided_slice_logic.h" +#include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" #include "tensorflow/lite/kernels/internal/types.h" @@ -5958,8 +5960,9 @@ inline void PadImageStyle(const tflite::PadParams& op_params, template inline void Slice(const tflite::SliceParams& op_params, - const RuntimeShape& input_shape, const T* input_data, - const RuntimeShape& output_shape, T* output_data) { + const RuntimeShape& input_shape, + const RuntimeShape& output_shape, + SequentialTensorWriter* writer) { gemmlowp::ScopedProfilingLabel label("Slice"); const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape); // TODO(dkalenichenko): This op only supports 4D tensors or smaller. @@ -5985,20 +5988,32 @@ inline void Slice(const tflite::SliceParams& op_params, ? ext_shape.Dims(3) - start_d : start_d + op_params.size[size_count - 1]; - T* out_ptr = output_data; for (int in_b = start_b; in_b < stop_b; ++in_b) { for (int in_h = start_h; in_h < stop_h; ++in_h) { for (int in_w = start_w; in_w < stop_w; ++in_w) { const int len = stop_d - start_d; - memcpy(out_ptr, - input_data + Offset(ext_shape, in_b, in_h, in_w, start_d), - len * sizeof(T)); - out_ptr += len; + writer->WriteN(Offset(ext_shape, in_b, in_h, in_w, start_d), len); } } } } +template +inline void Slice(const tflite::SliceParams& op_params, + const RuntimeShape& input_shape, const T* input_data, + const RuntimeShape& output_shape, T* output_data) { + SequentialTensorWriter writer(input_data, output_data); + return Slice(op_params, input_shape, output_shape, &writer); +} + +template +inline void Slice(const tflite::SliceParams& op_params, + const RuntimeShape& input_shape, const TfLiteTensor* input, + const RuntimeShape& output_shape, TfLiteTensor* output) { + SequentialTensorWriter writer(input, output); + return Slice(op_params, input_shape, output_shape, &writer); +} + template void Minimum(const RuntimeShape& input1_shape, const T* input1_data, const T* input2_data, const RuntimeShape& output_shape, diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index 85832768642..a24a47d3188 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -27,6 +27,7 @@ limitations under the License. #include "fixedpoint/fixedpoint.h" #include "public/gemmlowp.h" +#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/reference/conv.h" @@ -34,6 +35,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/softmax.h" #include "tensorflow/lite/kernels/internal/round.h" #include "tensorflow/lite/kernels/internal/strided_slice_logic.h" +#include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { @@ -3246,6 +3248,7 @@ inline void PadImageStyle(const tflite::PadParams& op_params, Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape, output_data); } + template inline void StridedSlice(const tflite::StridedSliceParams& op_params, const RuntimeShape& unextended_input_shape, @@ -3301,8 +3304,9 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params, template inline void Slice(const tflite::SliceParams& op_params, - const RuntimeShape& input_shape, const T* input_data, - const RuntimeShape& output_shape, T* output_data) { + const RuntimeShape& input_shape, + const RuntimeShape& output_shape, + SequentialTensorWriter* writer) { const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape); // TODO(dkalenichenko): This op only supports 4D tensors or smaller. TFLITE_DCHECK_LE(op_params.begin_count, 4); @@ -3327,18 +3331,33 @@ inline void Slice(const tflite::SliceParams& op_params, ? ext_shape.Dims(3) - start_d : start_d + op_params.size[size_count - 1]; - T* out_ptr = output_data; for (int in_b = start_b; in_b < stop_b; ++in_b) { for (int in_h = start_h; in_h < stop_h; ++in_h) { for (int in_w = start_w; in_w < stop_w; ++in_w) { for (int in_d = start_d; in_d < stop_d; ++in_d) { - *out_ptr++ = input_data[Offset(ext_shape, in_b, in_h, in_w, in_d)]; + writer->Write(Offset(ext_shape, in_b, in_h, in_w, in_d)); } } } } } +template +inline void Slice(const tflite::SliceParams& op_params, + const RuntimeShape& input_shape, const T* input_data, + const RuntimeShape& output_shape, T* output_data) { + SequentialTensorWriter writer(input_data, output_data); + return Slice(op_params, input_shape, output_shape, &writer); +} + +template +inline void Slice(const tflite::SliceParams& op_params, + const RuntimeShape& input_shape, const TfLiteTensor* input, + const RuntimeShape& output_shape, TfLiteTensor* output) { + SequentialTensorWriter writer(input, output); + return Slice(op_params, input_shape, output_shape, &writer); +} + template inline void Exp(const T* input_data, const size_t num_elements, T* output_data) { diff --git a/tensorflow/lite/kernels/internal/tensor.h b/tensorflow/lite/kernels/internal/tensor.h index b806753d886..94b0f7d8703 100644 --- a/tensorflow/lite/kernels/internal/tensor.h +++ b/tensorflow/lite/kernels/internal/tensor.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/types.h" +#include "tensorflow/lite/string_util.h" namespace tflite { @@ -109,6 +110,48 @@ class VectorOfQuantizedTensors : public VectorOfTensors { std::vector scale_; }; +// Writes randomly accessed values from `input` sequentially into `output`. +template +class SequentialTensorWriter { + public: + SequentialTensorWriter(const TfLiteTensor* input, TfLiteTensor* output) { + input_data_ = GetTensorData(input); + output_ptr_ = GetTensorData(output); + } + SequentialTensorWriter(const T* input_data, T* output_data) + : input_data_(input_data), output_ptr_(output_data) {} + + void Write(int position) { *output_ptr_++ = input_data_[position]; } + void WriteN(int position, int len) { + memcpy(output_ptr_, &input_data_[position], sizeof(T) * len); + output_ptr_ += len; + } + + private: + const T* input_data_; + T* output_ptr_; +}; + +template <> +class SequentialTensorWriter { + public: + SequentialTensorWriter(const TfLiteTensor* input, TfLiteTensor* output) + : input_(input), output_(output) {} + ~SequentialTensorWriter() { buffer_.WriteToTensor(output_, nullptr); } + + void Write(int position) { this->WriteN(position, 1); } + void WriteN(int position, int len) { + for (int i = 0; i < len; i++) { + buffer_.AddString(GetString(input_, position + i)); + } + } + + private: + const TfLiteTensor* input_; + TfLiteTensor* output_; + DynamicBuffer buffer_; +}; + } // namespace tflite #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_H_ diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index 3a80230536f..c30eefc2b92 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -324,8 +324,9 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_SELECT, Register_SELECT(), /* min_version */ 1, /* max_version */ 2); - AddBuiltin(BuiltinOperator_SLICE, Register_SLICE(), /* min_version */ 1, - /* max_version */ 2); + AddBuiltin(BuiltinOperator_SLICE, Register_SLICE(), + /* min_version */ 1, + /* max_version */ 3); AddBuiltin(BuiltinOperator_SIN, Register_SIN()); AddBuiltin(BuiltinOperator_COS, Register_COS()); AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV()); diff --git a/tensorflow/lite/kernels/slice.cc b/tensorflow/lite/kernels/slice.cc index 8472572d7e2..3b4ee40ed70 100644 --- a/tensorflow/lite/kernels/slice.cc +++ b/tensorflow/lite/kernels/slice.cc @@ -172,27 +172,25 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // The dimensions in the kernel used to be in reverse-order, and TFLite // arranged the begins and sizes vectors accordingly. This macro incorporates // the needed reversing. -#define TF_LITE_SLICE(data_type, kernel_type) \ - { \ - TF_LITE_ENSURE_EQ(context, begins.size(), 4); \ - TF_LITE_ENSURE_EQ(context, sizes.size(), 4); \ - tflite::SliceParams op_params; \ - op_params.begin_count = 4; \ - op_params.size_count = 4; \ - for (int i = 0; i < 4; ++i) { \ - op_params.begin[i] = begins[3 - i]; \ - op_params.size[i] = sizes[3 - i]; \ - } \ - \ - if (kernel_type == kGenericOptimized) { \ - optimized_ops::Slice( \ - op_params, GetTensorShape(input), GetTensorData(input), \ - GetTensorShape(output), GetTensorData(output)); \ - } else { \ - reference_ops::Slice( \ - op_params, GetTensorShape(input), GetTensorData(input), \ - GetTensorShape(output), GetTensorData(output)); \ - } \ +#define TF_LITE_SLICE(data_type, kernel_type) \ + { \ + TF_LITE_ENSURE_EQ(context, begins.size(), 4); \ + TF_LITE_ENSURE_EQ(context, sizes.size(), 4); \ + tflite::SliceParams op_params; \ + op_params.begin_count = 4; \ + op_params.size_count = 4; \ + for (int i = 0; i < 4; ++i) { \ + op_params.begin[i] = begins[3 - i]; \ + op_params.size[i] = sizes[3 - i]; \ + } \ + \ + if (kernel_type == kGenericOptimized) { \ + optimized_ops::Slice(op_params, GetTensorShape(input), input, \ + GetTensorShape(output), output); \ + } else { \ + reference_ops::Slice(op_params, GetTensorShape(input), input, \ + GetTensorShape(output), output); \ + } \ } switch (input->type) { @@ -214,6 +212,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteBool: TF_LITE_SLICE(bool, kernel_type); break; + case kTfLiteString: + TF_LITE_SLICE(string, kernel_type); + break; default: context->ReportError( context, "Type %d is currently not supported by Slice.", input->type); diff --git a/tensorflow/lite/kernels/slice_test.cc b/tensorflow/lite/kernels/slice_test.cc index 102218ba23c..b9b88215f2d 100644 --- a/tensorflow/lite/kernels/slice_test.cc +++ b/tensorflow/lite/kernels/slice_test.cc @@ -42,6 +42,9 @@ class SliceOpModel : public SingleOpModel { void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } + void SetStringInput(std::vector data) { + PopulateStringTensor(input_, data); + } void SetBegin(std::initializer_list data) { PopulateTensor(begin_, data); } @@ -185,6 +188,24 @@ TEST(SliceOpTest, SliceInt8) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5})); } +TEST(SliceOpTest, SliceString) { + SliceOpModel m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32, + TensorType_STRING); + m.SetStringInput({"0,0,0,0", "0,0,1,0", "0,0,2,0", // + "0,1,0,0", "0,1,1,0", "0,1,2,0", // + "1,0,0,0", "1,0,1,0", "1,0,2,0", // + "1,1,0,0", "1,1,1,0", "1,1,2,0", // + "2,0,0,0", "2,0,1,0", "2,0,2,0", // + "2,1,0,0", "2,1,1,0", "2,1,2,0"}); + m.SetBegin({1, 0, 0, 0}); + m.SetSize({2, 1, -1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({"1,0,0,0", "1,0,1,0", "1,0,2,0", // + "2,0,0,0", "2,0,1,0", "2,0,2,0"})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/testing/generate_examples_lib.py b/tensorflow/lite/testing/generate_examples_lib.py index 69878a0233c..22699df5291 100644 --- a/tensorflow/lite/testing/generate_examples_lib.py +++ b/tensorflow/lite/testing/generate_examples_lib.py @@ -3744,7 +3744,7 @@ def make_slice_tests(options): test_parameters = [ # 4-D { - "dtype": [tf.float32, tf.int32, tf.int64], + "dtype": [tf.float32, tf.int32, tf.int64, tf.string], "index_type": [tf.int32, tf.int64], "input_shape": [[12, 2, 2, 5]], "begin": [[0, 0, 0, 0], [1, 0, 1, 0]], @@ -3752,7 +3752,7 @@ def make_slice_tests(options): }, # 2-D { - "dtype": [tf.float32, tf.int32, tf.int64], + "dtype": [tf.float32, tf.int32, tf.int64, tf.string], "index_type": [tf.int32, tf.int64], "input_shape": [[2, 3]], "begin": [[0, 0], [1, 0]], @@ -3795,7 +3795,7 @@ def make_slice_tests(options): test_parameters, build_graph, build_inputs, - expected_tf_failures=18) + expected_tf_failures=24) @register_make_test_function() diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc index bc8f54cb26f..8b292af7085 100644 --- a/tensorflow/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -1709,10 +1709,14 @@ class Slice : public SimpleOperator { int GetVersion(const OperatorSignature& op_signature) const override { const string& input_name = op_signature.op->inputs[0]; const Array& input_array = op_signature.model->GetArray(input_name); - // Version 2 supports signed int8 input types. if (input_array.data_type == ArrayDataType::kInt8) { + // Version 2 supports signed int8 input types. return 2; } + if (input_array.data_type == ArrayDataType::kString) { + // Version 3 supports string input types. + return 3; + } return 1; } }; diff --git a/tensorflow/lite/toco/tflite/operator_test.cc b/tensorflow/lite/toco/tflite/operator_test.cc index a5b1efef5a1..a8b3bafd6f7 100644 --- a/tensorflow/lite/toco/tflite/operator_test.cc +++ b/tensorflow/lite/toco/tflite/operator_test.cc @@ -817,6 +817,18 @@ TEST_F(OperatorTest, VersioningSpaceToDepthTest) { TEST_F(OperatorTest, VersioningSliceTest) { SimpleVersioningTest(); + + // Check that a string input results in a version 3 op. + SliceOperator op; + op.inputs = {"input1"}; + auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/); + const BaseOperator* base_op = operator_by_type_map.at(op.type).get(); + + Model string_model; + Array& string_array = string_model.GetOrCreateArray(op.inputs[0]); + string_array.data_type = ArrayDataType::kString; + OperatorSignature string_signature = {.op = &op, .model = &string_model}; + EXPECT_EQ(base_op->GetVersion(string_signature), 3); } TEST_F(OperatorTest, VersioningLogisticTest) {