Support string input in TFLite StridedSlice kernel

PiperOrigin-RevId: 341957475
Change-Id: I96c79ba6a95b09861fe90120f3b6431f3d8e3a53
This commit is contained in:
Thai Nguyen 2020-11-11 18:58:32 -08:00 committed by TensorFlower Gardener
parent a802acb1ef
commit 68134a6024
10 changed files with 123 additions and 12 deletions
tensorflow

View File

@ -3405,7 +3405,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [
}];
let arguments = (ins
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8]>:$input,
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8, TFL_Str]>:$input,
TFL_I32Tensor:$begin,
TFL_I32Tensor:$end,
TFL_I32Tensor:$strides,
@ -3418,7 +3418,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [
);
let results = (outs
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8]>:$output
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8, TFL_Str]>:$output
);
let hasOptions = 1;

View File

@ -1122,6 +1122,13 @@ func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %arg1:
// CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 6 : i32, ellipsis_mask = 0 : i32, end_mask = 6 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10xf32>
}
func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> {
%0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
return %0 : tensor<1x2x2x5x!tf.string>
// CHECK-LABEL: strided_slice_with_string
// CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
}
func @slice1Tensor(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<?x3x5xf32> {
%0 = "tf.Slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
return %0 : tensor<?x3x5xf32>

View File

@ -1458,6 +1458,12 @@ func @testStridedSliceTFType(%arg0: tensor<12x2x2x5xui8>, %arg1: tensor<1xi32>,
return %0 : tensor<1x2x2x5x!tf.quint8>
}
// CHECK-LABEL: testStridedSliceWithString
func @testStridedSliceWithString(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> {
%0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
return %0 : tensor<1x2x2x5x!tf.string>
}
// -----
func @testStridedSliceWithInvalidOutputType(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xi32> {

View File

@ -17,18 +17,19 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/portable_tensor.h"
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace reference_ops {
template <typename T>
inline void StridedSlice(const tflite::StridedSliceParams& op_params,
const RuntimeShape& unextended_input_shape,
const T* input_data,
const RuntimeShape& unextended_output_shape,
T* output_data) {
SequentialTensorWriter<T>* writer) {
using strided_slice::LoopCondition;
using strided_slice::StartForAxis;
using strided_slice::StopForAxis;
@ -57,7 +58,6 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
const int start_4 = StartForAxis(params_copy, input_shape, 4);
const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4);
T* out_ptr = output_data;
for (int offset_0 = start_0 * input_shape.Dims(1),
end_0 = stop_0 * input_shape.Dims(1),
step_0 = params_copy.strides[0] * input_shape.Dims(1);
@ -81,13 +81,36 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4;
!LoopCondition(offset_4, end_4, params_copy.strides[4]);
offset_4 += params_copy.strides[4]) {
*out_ptr++ = input_data[offset_4];
writer->Write(offset_4);
}
}
}
}
}
}
template <typename T>
inline void StridedSlice(const tflite::StridedSliceParams& op_params,
const RuntimeShape& unextended_input_shape,
const T* input_data,
const RuntimeShape& unextended_output_shape,
T* output_data) {
SequentialTensorWriter<T> writer(input_data, output_data);
StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
&writer);
}
template <typename T>
inline void StridedSlice(const tflite::StridedSliceParams& op_params,
const RuntimeShape& unextended_input_shape,
const TfLiteTensor* input,
const RuntimeShape& unextended_output_shape,
TfLiteTensor* output) {
SequentialTensorWriter<T> writer(input, output);
StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
&writer);
}
} // namespace reference_ops
} // namespace tflite

View File

@ -157,7 +157,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE(),
/* min_version = */ 1,
/* max_version = */ 4);
/* max_version = */ 5);
AddBuiltin(BuiltinOperator_EXP, Register_EXP());
AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2(),
/* min_version = */ 1,

View File

@ -190,11 +190,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
StridedSliceParams op_params = BuildStridedSliceParams(&op_context);
#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
kernel_type::StridedSlice(op_params, GetTensorShape(op_context.input), \
GetTensorData<data_type>(op_context.input), \
GetTensorShape(op_context.output), \
GetTensorData<data_type>(op_context.output))
#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
kernel_type::StridedSlice<data_type>( \
op_params, GetTensorShape(op_context.input), op_context.input, \
GetTensorShape(op_context.output), op_context.output)
switch (op_context.input->type) {
case kTfLiteFloat32:
@ -232,6 +231,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_STRIDED_SLICE(reference_ops, bool);
}
break;
case kTfLiteString:
if (kernel_type == kReference) {
TF_LITE_STRIDED_SLICE(reference_ops, string);
}
break;
default:
TF_LITE_KERNEL_LOG(context,
"Type %s is currently not supported "

View File

@ -55,6 +55,9 @@ class StridedSliceOpModel : public SingleOpModel {
void SetInput(const std::vector<input_type> data) {
PopulateTensor<input_type>(input_, data);
}
void SetStringInput(std::initializer_list<string> data) {
PopulateStringTensor(input_, data);
}
void SetBegin(std::initializer_list<int32_t> data) {
PopulateTensor<int32_t>(begin_, data);
}
@ -68,6 +71,9 @@ class StridedSliceOpModel : public SingleOpModel {
std::vector<input_type> GetOutput() {
return ExtractVector<input_type>(output_);
}
std::vector<string> GetStringOutput() {
return ExtractVector<string>(output_);
}
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
private:
@ -692,5 +698,52 @@ TYPED_TEST(StridedSliceOpTest, In3D_Backward) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0, 1, 2}));
}
TEST(StridedSliceOpTest, In1D_String_NegativeBegin) {
StridedSliceOpModel<std::string> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetStringInput({"a", "b", "c", "d"});
m.SetBegin({-3});
m.SetEnd({3});
m.SetStrides({1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"b", "c"}));
}
TEST(StridedSliceOpTest, In3D_String_BackwardSmallBegin) {
StridedSliceOpModel<std::string> m({1, 1, 2}, {1}, {1}, {1}, 0, 1, 0, 0, 0);
m.SetStringInput({"a", "b"});
m.SetBegin({1});
m.SetEnd({0});
m.SetStrides({1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0, 1, 2}));
}
TEST(StridedSliceOpTest, In3D_String_SmallBeginWithhrinkAxis1) {
StridedSliceOpModel<std::string> m({2, 3, 2}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
m.SetStringInput(
{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"});
m.SetBegin({0});
m.SetEnd({1});
m.SetStrides({1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2}));
EXPECT_THAT(m.GetStringOutput(),
ElementsAreArray({"1", "2", "3", "4", "5", "6"}));
}
TEST(StridedSliceOpTest, In5D_String_IdentityShrinkAxis1) {
StridedSliceOpModel<std::string> m({2, 2, 2, 1, 2}, {5}, {5}, {5}, 0, 0, 0, 0,
1);
m.SetStringInput({"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11",
"12", "13", "14", "15", "16"});
m.SetBegin({0, 0, 0, 0, 0});
m.SetEnd({2, 1, 2, 1, 2});
m.SetStrides({1, 1, 1, 1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 1, 2}));
EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"1", "2", "3", "4"}));
}
} // namespace
} // namespace tflite

View File

@ -230,6 +230,20 @@ def make_strided_slice_tests(options):
"shrink_axis_mask": [0],
"constant_indices": [True, False],
"fully_quantize": [False],
},
# String input.
{
"dtype": [tf.string],
"index_type": [tf.int32],
"input_shape": [[12, 2, 2, 5]],
"begin": [[0, 0, 0, 0]],
"end": [[8, 2, 2, 3]],
"strides": [[2, 1, 3, 1]],
"begin_mask": [8],
"end_mask": [3],
"shrink_axis_mask": [None, -1],
"constant_indices": [True, False],
"fully_quantize": [False],
}
]
_make_strided_slice_tests(options, test_parameters, expected_tf_failures=2)

View File

@ -387,6 +387,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
return 1;
case BuiltinOperator_STRIDED_SLICE:
if (op_sig.input_types.at(0) == TensorType_STRING) {
return 5;
}
if (op_sig.options.single_input_op.num_dims > 4) {
return 4;
}

View File

@ -218,6 +218,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_STRIDED_SLICE, 2}, "1.14.0"},
{{BuiltinOperator_STRIDED_SLICE, 3}, "2.1.0"},
{{BuiltinOperator_STRIDED_SLICE, 4}, "2.2.0"},
{{BuiltinOperator_STRIDED_SLICE, 5}, kPendingReleaseVersion},
{{BuiltinOperator_TOPK_V2, 1}, "1.7.0"},
{{BuiltinOperator_TOPK_V2, 2}, "1.14.0"},
{{BuiltinOperator_ARG_MAX, 1}, "1.9.0"},