Add string support to slice.
PiperOrigin-RevId: 242045915
parent 51a3acd6a0
commit 0c4d4090c9
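
The change reworks both the optimized and the reference Slice implementations around a new SequentialTensorWriter<T> helper added to the internal tensor.h header (see the TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_H_ hunk below): the kernel walks input offsets and the writer copies the selected elements sequentially into the output. The primary template copies raw values, while the string specialization accumulates the selected strings in a DynamicBuffer and flushes it to the output tensor on destruction. The Slice kernel then dispatches kTfLiteString through new TfLiteTensor-based overloads, and the op is registered and versioned as version 3. Below is a minimal sketch of the writer contract for a plain element type; it is not taken from the diff, and the function name and buffers are illustrative assumptions.

#include "tensorflow/lite/kernels/internal/tensor.h"

// Illustrative only: copies input[0], then input[4..6], into consecutive
// output slots through the new writer interface.
void WriterSketch(const float* input, float* output) {
  tflite::SequentialTensorWriter<float> writer(input, output);
  writer.Write(0);      // *output_ptr_++ = input_data_[0]
  writer.WriteN(4, 3);  // one memcpy of three elements starting at input_data_[4]
}
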
@@ -189,6 +189,7 @@ cc_library(
        ":types",
        ":reference_base",
        ":round",
        ":tensor",
        ":tensor_utils",
        "//third_party/eigen3",
        "@gemmlowp",
@@ -222,6 +223,7 @@ cc_library(
    deps = [
        ":quantization_util",
        ":strided_slice_logic",
        ":tensor",
        ":tensor_utils",
        ":types",
        ":legacy_types",
@@ -258,6 +260,7 @@ cc_library(
        ":tensor",
        ":types",
        "//tensorflow/core/kernels:eigen_spatial_convolutions-inl",
        "//tensorflow/lite:string_util",
        "//tensorflow/lite/c:c_api_internal",
        "//third_party/eigen3",
    ],
@@ -341,6 +344,7 @@ cc_library(
        ":quantization_util",
        ":round",
        ":strided_slice_logic",
        ":tensor",
        ":types",
        "@gemmlowp",
        "//tensorflow/lite/c:c_api_internal",
@@ -376,6 +380,7 @@ cc_library(
        ":round",
        ":strided_slice_logic",
        ":legacy_types",
        ":tensor",
        ":types",
        "@gemmlowp",
        "//tensorflow/lite/c:c_api_internal",
@@ -401,6 +406,7 @@ cc_library(
    ],
    deps = [
        ":types",
        "//tensorflow/lite:string_util",
        "//tensorflow/lite/c:c_api_internal",
    ],
)
@@ -414,6 +420,7 @@ cc_library(
    ],
    deps = [
        ":types",
        "//tensorflow/lite:string_util",
        "//tensorflow/lite/c:c_api_internal",
    ],
)

@@ -34,12 +34,14 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "fixedpoint/fixedpoint.h"
#include "public/gemmlowp.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/round.h"
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"

@@ -5958,8 +5960,9 @@ inline void PadImageStyle(const tflite::PadParams& op_params,

template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
                  const RuntimeShape& input_shape, const T* input_data,
                  const RuntimeShape& output_shape, T* output_data) {
                  const RuntimeShape& input_shape,
                  const RuntimeShape& output_shape,
                  SequentialTensorWriter<T>* writer) {
  gemmlowp::ScopedProfilingLabel label("Slice");
  const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
  // TODO(dkalenichenko): This op only supports 4D tensors or smaller.
@@ -5985,20 +5988,32 @@ inline void Slice(const tflite::SliceParams& op_params,
                        ? ext_shape.Dims(3) - start_d
                        : start_d + op_params.size[size_count - 1];

  T* out_ptr = output_data;
  for (int in_b = start_b; in_b < stop_b; ++in_b) {
    for (int in_h = start_h; in_h < stop_h; ++in_h) {
      for (int in_w = start_w; in_w < stop_w; ++in_w) {
        const int len = stop_d - start_d;
        memcpy(out_ptr,
               input_data + Offset(ext_shape, in_b, in_h, in_w, start_d),
               len * sizeof(T));
        out_ptr += len;
        writer->WriteN(Offset(ext_shape, in_b, in_h, in_w, start_d), len);
      }
    }
  }
}

template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
                  const RuntimeShape& input_shape, const T* input_data,
                  const RuntimeShape& output_shape, T* output_data) {
  SequentialTensorWriter<T> writer(input_data, output_data);
  return Slice(op_params, input_shape, output_shape, &writer);
}

template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
                  const RuntimeShape& input_shape, const TfLiteTensor* input,
                  const RuntimeShape& output_shape, TfLiteTensor* output) {
  SequentialTensorWriter<T> writer(input, output);
  return Slice(op_params, input_shape, output_shape, &writer);
}

template <typename T>
void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
             const T* input2_data, const RuntimeShape& output_shape,

@@ -27,6 +27,7 @@ limitations under the License.

#include "fixedpoint/fixedpoint.h"
#include "public/gemmlowp.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
@@ -34,6 +35,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
#include "tensorflow/lite/kernels/internal/round.h"
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
@@ -3246,6 +3248,7 @@ inline void PadImageStyle(const tflite::PadParams& op_params,
  Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape,
      output_data);
}

template <typename T>
inline void StridedSlice(const tflite::StridedSliceParams& op_params,
                         const RuntimeShape& unextended_input_shape,
@@ -3301,8 +3304,9 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,

template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
                  const RuntimeShape& input_shape, const T* input_data,
                  const RuntimeShape& output_shape, T* output_data) {
                  const RuntimeShape& input_shape,
                  const RuntimeShape& output_shape,
                  SequentialTensorWriter<T>* writer) {
  const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
  // TODO(dkalenichenko): This op only supports 4D tensors or smaller.
  TFLITE_DCHECK_LE(op_params.begin_count, 4);
@@ -3327,18 +3331,33 @@ inline void Slice(const tflite::SliceParams& op_params,
                        ? ext_shape.Dims(3) - start_d
                        : start_d + op_params.size[size_count - 1];

  T* out_ptr = output_data;
  for (int in_b = start_b; in_b < stop_b; ++in_b) {
    for (int in_h = start_h; in_h < stop_h; ++in_h) {
      for (int in_w = start_w; in_w < stop_w; ++in_w) {
        for (int in_d = start_d; in_d < stop_d; ++in_d) {
          *out_ptr++ = input_data[Offset(ext_shape, in_b, in_h, in_w, in_d)];
          writer->Write(Offset(ext_shape, in_b, in_h, in_w, in_d));
        }
      }
    }
  }
}

template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
                  const RuntimeShape& input_shape, const T* input_data,
                  const RuntimeShape& output_shape, T* output_data) {
  SequentialTensorWriter<T> writer(input_data, output_data);
  return Slice(op_params, input_shape, output_shape, &writer);
}

template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
                  const RuntimeShape& input_shape, const TfLiteTensor* input,
                  const RuntimeShape& output_shape, TfLiteTensor* output) {
  SequentialTensorWriter<T> writer(input, output);
  return Slice(op_params, input_shape, output_shape, &writer);
}

template <typename T>
inline void Exp(const T* input_data, const size_t num_elements,
                T* output_data) {

@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/string_util.h"

namespace tflite {

@@ -109,6 +110,48 @@ class VectorOfQuantizedTensors : public VectorOfTensors<uint8> {
  std::vector<float> scale_;
};

// Writes randomly accessed values from `input` sequentially into `output`.
template <typename T>
class SequentialTensorWriter {
 public:
  SequentialTensorWriter(const TfLiteTensor* input, TfLiteTensor* output) {
    input_data_ = GetTensorData<T>(input);
    output_ptr_ = GetTensorData<T>(output);
  }
  SequentialTensorWriter(const T* input_data, T* output_data)
      : input_data_(input_data), output_ptr_(output_data) {}

  void Write(int position) { *output_ptr_++ = input_data_[position]; }
  void WriteN(int position, int len) {
    memcpy(output_ptr_, &input_data_[position], sizeof(T) * len);
    output_ptr_ += len;
  }

 private:
  const T* input_data_;
  T* output_ptr_;
};

template <>
class SequentialTensorWriter<string> {
 public:
  SequentialTensorWriter(const TfLiteTensor* input, TfLiteTensor* output)
      : input_(input), output_(output) {}
  ~SequentialTensorWriter() { buffer_.WriteToTensor(output_, nullptr); }

  void Write(int position) { this->WriteN(position, 1); }
  void WriteN(int position, int len) {
    for (int i = 0; i < len; i++) {
      buffer_.AddString(GetString(input_, position + i));
    }
  }

 private:
  const TfLiteTensor* input_;
  TfLiteTensor* output_;
  DynamicBuffer buffer_;
};

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_H_

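For context, a minimal sketch, not part of this change, of how the string path above is exercised end to end: the TfLiteTensor overload of Slice constructs a SequentialTensorWriter<string>, each selected string is appended to the writer's DynamicBuffer, and the buffer is written into the output tensor when the writer goes out of scope. The helper name SliceOuterDim and the fixed begin/size values are assumptions for illustration only.

#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"

namespace tflite {

// Illustrative helper: copy everything except the first slice along the
// outermost dimension of a 4-D string tensor into `output`.
void SliceOuterDim(const TfLiteTensor* input, TfLiteTensor* output) {
  SliceParams op_params;
  op_params.begin_count = 4;
  op_params.size_count = 4;
  for (int i = 0; i < 4; ++i) {
    op_params.begin[i] = 0;
    op_params.size[i] = -1;  // -1 selects up to the end of the dimension,
                             // as in the SliceString test below.
  }
  op_params.begin[0] = 1;  // Skip the first outer slice.
  // Internally this builds SequentialTensorWriter<string>(input, output);
  // the selected strings are gathered into a DynamicBuffer and flushed to
  // `output` when the writer is destroyed.
  reference_ops::Slice<string>(op_params, GetTensorShape(input), input,
                               GetTensorShape(output), output);
}

}  // namespace tflite
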
@@ -324,8 +324,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
  AddBuiltin(BuiltinOperator_SELECT, Register_SELECT(),
             /* min_version */ 1,
             /* max_version */ 2);
  AddBuiltin(BuiltinOperator_SLICE, Register_SLICE(), /* min_version */ 1,
             /* max_version */ 2);
  AddBuiltin(BuiltinOperator_SLICE, Register_SLICE(),
             /* min_version */ 1,
             /* max_version */ 3);
  AddBuiltin(BuiltinOperator_SIN, Register_SIN());
  AddBuiltin(BuiltinOperator_COS, Register_COS());
  AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV());

@@ -172,27 +172,25 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  // The dimensions in the kernel used to be in reverse-order, and TFLite
  // arranged the begins and sizes vectors accordingly. This macro incorporates
  // the needed reversing.
#define TF_LITE_SLICE(data_type, kernel_type) \
  { \
    TF_LITE_ENSURE_EQ(context, begins.size(), 4); \
    TF_LITE_ENSURE_EQ(context, sizes.size(), 4); \
    tflite::SliceParams op_params; \
    op_params.begin_count = 4; \
    op_params.size_count = 4; \
    for (int i = 0; i < 4; ++i) { \
      op_params.begin[i] = begins[3 - i]; \
      op_params.size[i] = sizes[3 - i]; \
    } \
    \
    if (kernel_type == kGenericOptimized) { \
      optimized_ops::Slice<data_type>( \
          op_params, GetTensorShape(input), GetTensorData<data_type>(input), \
          GetTensorShape(output), GetTensorData<data_type>(output)); \
    } else { \
      reference_ops::Slice<data_type>( \
          op_params, GetTensorShape(input), GetTensorData<data_type>(input), \
          GetTensorShape(output), GetTensorData<data_type>(output)); \
    } \
#define TF_LITE_SLICE(data_type, kernel_type) \
  { \
    TF_LITE_ENSURE_EQ(context, begins.size(), 4); \
    TF_LITE_ENSURE_EQ(context, sizes.size(), 4); \
    tflite::SliceParams op_params; \
    op_params.begin_count = 4; \
    op_params.size_count = 4; \
    for (int i = 0; i < 4; ++i) { \
      op_params.begin[i] = begins[3 - i]; \
      op_params.size[i] = sizes[3 - i]; \
    } \
    \
    if (kernel_type == kGenericOptimized) { \
      optimized_ops::Slice<data_type>(op_params, GetTensorShape(input), input, \
                                      GetTensorShape(output), output); \
    } else { \
      reference_ops::Slice<data_type>(op_params, GetTensorShape(input), input, \
                                      GetTensorShape(output), output); \
    } \
  }

  switch (input->type) {
@@ -214,6 +212,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
    case kTfLiteBool:
      TF_LITE_SLICE(bool, kernel_type);
      break;
    case kTfLiteString:
      TF_LITE_SLICE(string, kernel_type);
      break;
    default:
      context->ReportError(
          context, "Type %d is currently not supported by Slice.", input->type);

@@ -42,6 +42,9 @@ class SliceOpModel : public SingleOpModel {
  void SetInput(std::initializer_list<input_type> data) {
    PopulateTensor<input_type>(input_, data);
  }
  void SetStringInput(std::vector<string> data) {
    PopulateStringTensor(input_, data);
  }
  void SetBegin(std::initializer_list<index_type> data) {
    PopulateTensor<index_type>(begin_, data);
  }
@@ -185,6 +188,24 @@ TEST(SliceOpTest, SliceInt8) {
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
}

TEST(SliceOpTest, SliceString) {
  SliceOpModel<string, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
                                  TensorType_STRING);
  m.SetStringInput({"0,0,0,0", "0,0,1,0", "0,0,2,0", //
                    "0,1,0,0", "0,1,1,0", "0,1,2,0", //
                    "1,0,0,0", "1,0,1,0", "1,0,2,0", //
                    "1,1,0,0", "1,1,1,0", "1,1,2,0", //
                    "2,0,0,0", "2,0,1,0", "2,0,2,0", //
                    "2,1,0,0", "2,1,1,0", "2,1,2,0"});
  m.SetBegin({1, 0, 0, 0});
  m.SetSize({2, 1, -1, 1});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray({"1,0,0,0", "1,0,1,0", "1,0,2,0", //
                                "2,0,0,0", "2,0,1,0", "2,0,2,0"}));
}

}  // namespace
}  // namespace tflite

@@ -3744,7 +3744,7 @@ def make_slice_tests(options):
  test_parameters = [
      # 4-D
      {
          "dtype": [tf.float32, tf.int32, tf.int64],
          "dtype": [tf.float32, tf.int32, tf.int64, tf.string],
          "index_type": [tf.int32, tf.int64],
          "input_shape": [[12, 2, 2, 5]],
          "begin": [[0, 0, 0, 0], [1, 0, 1, 0]],
@@ -3752,7 +3752,7 @@
      },
      # 2-D
      {
          "dtype": [tf.float32, tf.int32, tf.int64],
          "dtype": [tf.float32, tf.int32, tf.int64, tf.string],
          "index_type": [tf.int32, tf.int64],
          "input_shape": [[2, 3]],
          "begin": [[0, 0], [1, 0]],
@@ -3795,7 +3795,7 @@
      test_parameters,
      build_graph,
      build_inputs,
      expected_tf_failures=18)
      expected_tf_failures=24)


@register_make_test_function()

@@ -1709,10 +1709,14 @@ class Slice : public SimpleOperator<SliceOperator> {
  int GetVersion(const OperatorSignature& op_signature) const override {
    const string& input_name = op_signature.op->inputs[0];
    const Array& input_array = op_signature.model->GetArray(input_name);
    // Version 2 supports signed int8 input types.
    if (input_array.data_type == ArrayDataType::kInt8) {
      // Version 2 supports signed int8 input types.
      return 2;
    }
    if (input_array.data_type == ArrayDataType::kString) {
      // Version 3 supports string input types.
      return 3;
    }
    return 1;
  }
};

@@ -817,6 +817,18 @@ TEST_F(OperatorTest, VersioningSpaceToDepthTest) {

TEST_F(OperatorTest, VersioningSliceTest) {
  SimpleVersioningTest<SliceOperator>();

  // Check that a string input results in a version 3 op.
  SliceOperator op;
  op.inputs = {"input1"};
  auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
  const BaseOperator* base_op = operator_by_type_map.at(op.type).get();

  Model string_model;
  Array& string_array = string_model.GetOrCreateArray(op.inputs[0]);
  string_array.data_type = ArrayDataType::kString;
  OperatorSignature string_signature = {.op = &op, .model = &string_model};
  EXPECT_EQ(base_op->GetVersion(string_signature), 3);
}

TEST_F(OperatorTest, VersioningLogisticTest) {