Support string input in TFLite StridedSlice kernel
PiperOrigin-RevId: 341957475 Change-Id: I96c79ba6a95b09861fe90120f3b6431f3d8e3a53
This commit is contained in:
parent
a802acb1ef
commit
68134a6024
tensorflow
compiler/mlir/lite
lite
kernels
testing/op_tests
tools/versioning
@ -3405,7 +3405,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8]>:$input,
|
||||
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8, TFL_Str]>:$input,
|
||||
TFL_I32Tensor:$begin,
|
||||
TFL_I32Tensor:$end,
|
||||
TFL_I32Tensor:$strides,
|
||||
@ -3418,7 +3418,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8]>:$output
|
||||
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8, TFL_Str]>:$output
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
|
@ -1122,6 +1122,13 @@ func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %arg1:
|
||||
// CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 6 : i32, ellipsis_mask = 0 : i32, end_mask = 6 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10xf32>
|
||||
}
|
||||
|
||||
func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> {
|
||||
%0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
|
||||
return %0 : tensor<1x2x2x5x!tf.string>
|
||||
// CHECK-LABEL: strided_slice_with_string
|
||||
// CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
|
||||
}
|
||||
|
||||
func @slice1Tensor(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<?x3x5xf32> {
|
||||
%0 = "tf.Slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
|
||||
return %0 : tensor<?x3x5xf32>
|
||||
|
@ -1458,6 +1458,12 @@ func @testStridedSliceTFType(%arg0: tensor<12x2x2x5xui8>, %arg1: tensor<1xi32>,
|
||||
return %0 : tensor<1x2x2x5x!tf.quint8>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testStridedSliceWithString
|
||||
func @testStridedSliceWithString(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> {
|
||||
%0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
|
||||
return %0 : tensor<1x2x2x5x!tf.string>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testStridedSliceWithInvalidOutputType(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xi32> {
|
||||
|
@ -17,18 +17,19 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/kernels/internal/portable_tensor.h"
|
||||
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
namespace reference_ops {
|
||||
|
||||
template <typename T>
|
||||
inline void StridedSlice(const tflite::StridedSliceParams& op_params,
|
||||
const RuntimeShape& unextended_input_shape,
|
||||
const T* input_data,
|
||||
const RuntimeShape& unextended_output_shape,
|
||||
T* output_data) {
|
||||
SequentialTensorWriter<T>* writer) {
|
||||
using strided_slice::LoopCondition;
|
||||
using strided_slice::StartForAxis;
|
||||
using strided_slice::StopForAxis;
|
||||
@ -57,7 +58,6 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
|
||||
const int start_4 = StartForAxis(params_copy, input_shape, 4);
|
||||
const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4);
|
||||
|
||||
T* out_ptr = output_data;
|
||||
for (int offset_0 = start_0 * input_shape.Dims(1),
|
||||
end_0 = stop_0 * input_shape.Dims(1),
|
||||
step_0 = params_copy.strides[0] * input_shape.Dims(1);
|
||||
@ -81,13 +81,36 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
|
||||
for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4;
|
||||
!LoopCondition(offset_4, end_4, params_copy.strides[4]);
|
||||
offset_4 += params_copy.strides[4]) {
|
||||
*out_ptr++ = input_data[offset_4];
|
||||
writer->Write(offset_4);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void StridedSlice(const tflite::StridedSliceParams& op_params,
|
||||
const RuntimeShape& unextended_input_shape,
|
||||
const T* input_data,
|
||||
const RuntimeShape& unextended_output_shape,
|
||||
T* output_data) {
|
||||
SequentialTensorWriter<T> writer(input_data, output_data);
|
||||
StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
|
||||
&writer);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void StridedSlice(const tflite::StridedSliceParams& op_params,
|
||||
const RuntimeShape& unextended_input_shape,
|
||||
const TfLiteTensor* input,
|
||||
const RuntimeShape& unextended_output_shape,
|
||||
TfLiteTensor* output) {
|
||||
SequentialTensorWriter<T> writer(input, output);
|
||||
StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
|
||||
&writer);
|
||||
}
|
||||
|
||||
} // namespace reference_ops
|
||||
} // namespace tflite
|
||||
|
||||
|
@ -157,7 +157,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
/* max_version = */ 2);
|
||||
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE(),
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 4);
|
||||
/* max_version = */ 5);
|
||||
AddBuiltin(BuiltinOperator_EXP, Register_EXP());
|
||||
AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2(),
|
||||
/* min_version = */ 1,
|
||||
|
@ -190,11 +190,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
StridedSliceParams op_params = BuildStridedSliceParams(&op_context);
|
||||
|
||||
#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
|
||||
kernel_type::StridedSlice(op_params, GetTensorShape(op_context.input), \
|
||||
GetTensorData<data_type>(op_context.input), \
|
||||
GetTensorShape(op_context.output), \
|
||||
GetTensorData<data_type>(op_context.output))
|
||||
#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
|
||||
kernel_type::StridedSlice<data_type>( \
|
||||
op_params, GetTensorShape(op_context.input), op_context.input, \
|
||||
GetTensorShape(op_context.output), op_context.output)
|
||||
|
||||
switch (op_context.input->type) {
|
||||
case kTfLiteFloat32:
|
||||
@ -232,6 +231,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_STRIDED_SLICE(reference_ops, bool);
|
||||
}
|
||||
break;
|
||||
case kTfLiteString:
|
||||
if (kernel_type == kReference) {
|
||||
TF_LITE_STRIDED_SLICE(reference_ops, string);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Type %s is currently not supported "
|
||||
|
@ -55,6 +55,9 @@ class StridedSliceOpModel : public SingleOpModel {
|
||||
void SetInput(const std::vector<input_type> data) {
|
||||
PopulateTensor<input_type>(input_, data);
|
||||
}
|
||||
void SetStringInput(std::initializer_list<string> data) {
|
||||
PopulateStringTensor(input_, data);
|
||||
}
|
||||
void SetBegin(std::initializer_list<int32_t> data) {
|
||||
PopulateTensor<int32_t>(begin_, data);
|
||||
}
|
||||
@ -68,6 +71,9 @@ class StridedSliceOpModel : public SingleOpModel {
|
||||
std::vector<input_type> GetOutput() {
|
||||
return ExtractVector<input_type>(output_);
|
||||
}
|
||||
std::vector<string> GetStringOutput() {
|
||||
return ExtractVector<string>(output_);
|
||||
}
|
||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||
|
||||
private:
|
||||
@ -692,5 +698,52 @@ TYPED_TEST(StridedSliceOpTest, In3D_Backward) {
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0, 1, 2}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In1D_String_NegativeBegin) {
|
||||
StridedSliceOpModel<std::string> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
m.SetStringInput({"a", "b", "c", "d"});
|
||||
m.SetBegin({-3});
|
||||
m.SetEnd({3});
|
||||
m.SetStrides({1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
|
||||
EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"b", "c"}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In3D_String_BackwardSmallBegin) {
|
||||
StridedSliceOpModel<std::string> m({1, 1, 2}, {1}, {1}, {1}, 0, 1, 0, 0, 0);
|
||||
m.SetStringInput({"a", "b"});
|
||||
m.SetBegin({1});
|
||||
m.SetEnd({0});
|
||||
m.SetStrides({1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0, 1, 2}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In3D_String_SmallBeginWithhrinkAxis1) {
|
||||
StridedSliceOpModel<std::string> m({2, 3, 2}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
|
||||
m.SetStringInput(
|
||||
{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"});
|
||||
m.SetBegin({0});
|
||||
m.SetEnd({1});
|
||||
m.SetStrides({1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2}));
|
||||
EXPECT_THAT(m.GetStringOutput(),
|
||||
ElementsAreArray({"1", "2", "3", "4", "5", "6"}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In5D_String_IdentityShrinkAxis1) {
|
||||
StridedSliceOpModel<std::string> m({2, 2, 2, 1, 2}, {5}, {5}, {5}, 0, 0, 0, 0,
|
||||
1);
|
||||
m.SetStringInput({"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11",
|
||||
"12", "13", "14", "15", "16"});
|
||||
m.SetBegin({0, 0, 0, 0, 0});
|
||||
m.SetEnd({2, 1, 2, 1, 2});
|
||||
m.SetStrides({1, 1, 1, 1, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 1, 2}));
|
||||
EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"1", "2", "3", "4"}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
@ -230,6 +230,20 @@ def make_strided_slice_tests(options):
|
||||
"shrink_axis_mask": [0],
|
||||
"constant_indices": [True, False],
|
||||
"fully_quantize": [False],
|
||||
},
|
||||
# String input.
|
||||
{
|
||||
"dtype": [tf.string],
|
||||
"index_type": [tf.int32],
|
||||
"input_shape": [[12, 2, 2, 5]],
|
||||
"begin": [[0, 0, 0, 0]],
|
||||
"end": [[8, 2, 2, 3]],
|
||||
"strides": [[2, 1, 3, 1]],
|
||||
"begin_mask": [8],
|
||||
"end_mask": [3],
|
||||
"shrink_axis_mask": [None, -1],
|
||||
"constant_indices": [True, False],
|
||||
"fully_quantize": [False],
|
||||
}
|
||||
]
|
||||
_make_strided_slice_tests(options, test_parameters, expected_tf_failures=2)
|
||||
|
@ -387,6 +387,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_STRIDED_SLICE:
|
||||
if (op_sig.input_types.at(0) == TensorType_STRING) {
|
||||
return 5;
|
||||
}
|
||||
if (op_sig.options.single_input_op.num_dims > 4) {
|
||||
return 4;
|
||||
}
|
||||
|
@ -218,6 +218,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
|
||||
{{BuiltinOperator_STRIDED_SLICE, 2}, "1.14.0"},
|
||||
{{BuiltinOperator_STRIDED_SLICE, 3}, "2.1.0"},
|
||||
{{BuiltinOperator_STRIDED_SLICE, 4}, "2.2.0"},
|
||||
{{BuiltinOperator_STRIDED_SLICE, 5}, kPendingReleaseVersion},
|
||||
{{BuiltinOperator_TOPK_V2, 1}, "1.7.0"},
|
||||
{{BuiltinOperator_TOPK_V2, 2}, "1.14.0"},
|
||||
{{BuiltinOperator_ARG_MAX, 1}, "1.9.0"},
|
||||
|
Loading…
Reference in New Issue
Block a user