Add string support to slice.

PiperOrigin-RevId: 242045915
This commit is contained in:
A. Unique TensorFlower 2019-04-04 17:55:17 -07:00 committed by TensorFlower Gardener
parent 51a3acd6a0
commit 0c4d4090c9
10 changed files with 161 additions and 38 deletions

View File

@ -189,6 +189,7 @@ cc_library(
":types",
":reference_base",
":round",
":tensor",
":tensor_utils",
"//third_party/eigen3",
"@gemmlowp",
@ -222,6 +223,7 @@ cc_library(
deps = [
":quantization_util",
":strided_slice_logic",
":tensor",
":tensor_utils",
":types",
":legacy_types",
@ -258,6 +260,7 @@ cc_library(
":tensor",
":types",
"//tensorflow/core/kernels:eigen_spatial_convolutions-inl",
"//tensorflow/lite:string_util",
"//tensorflow/lite/c:c_api_internal",
"//third_party/eigen3",
],
@ -341,6 +344,7 @@ cc_library(
":quantization_util",
":round",
":strided_slice_logic",
":tensor",
":types",
"@gemmlowp",
"//tensorflow/lite/c:c_api_internal",
@ -376,6 +380,7 @@ cc_library(
":round",
":strided_slice_logic",
":legacy_types",
":tensor",
":types",
"@gemmlowp",
"//tensorflow/lite/c:c_api_internal",
@ -401,6 +406,7 @@ cc_library(
],
deps = [
":types",
"//tensorflow/lite:string_util",
"//tensorflow/lite/c:c_api_internal",
],
)
@ -414,6 +420,7 @@ cc_library(
],
deps = [
":types",
"//tensorflow/lite:string_util",
"//tensorflow/lite/c:c_api_internal",
],
)

View File

@ -34,12 +34,14 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "fixedpoint/fixedpoint.h"
#include "public/gemmlowp.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/round.h"
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
@ -5958,8 +5960,9 @@ inline void PadImageStyle(const tflite::PadParams& op_params,
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& output_shape, T* output_data) {
const RuntimeShape& input_shape,
const RuntimeShape& output_shape,
SequentialTensorWriter<T>* writer) {
gemmlowp::ScopedProfilingLabel label("Slice");
const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
// TODO(dkalenichenko): This op only supports 4D tensors or smaller.
@ -5985,20 +5988,32 @@ inline void Slice(const tflite::SliceParams& op_params,
? ext_shape.Dims(3) - start_d
: start_d + op_params.size[size_count - 1];
T* out_ptr = output_data;
for (int in_b = start_b; in_b < stop_b; ++in_b) {
for (int in_h = start_h; in_h < stop_h; ++in_h) {
for (int in_w = start_w; in_w < stop_w; ++in_w) {
const int len = stop_d - start_d;
memcpy(out_ptr,
input_data + Offset(ext_shape, in_b, in_h, in_w, start_d),
len * sizeof(T));
out_ptr += len;
writer->WriteN(Offset(ext_shape, in_b, in_h, in_w, start_d), len);
}
}
}
}
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& output_shape, T* output_data) {
SequentialTensorWriter<T> writer(input_data, output_data);
return Slice(op_params, input_shape, output_shape, &writer);
}
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const TfLiteTensor* input,
const RuntimeShape& output_shape, TfLiteTensor* output) {
SequentialTensorWriter<T> writer(input, output);
return Slice(op_params, input_shape, output_shape, &writer);
}
template <typename T>
void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
const T* input2_data, const RuntimeShape& output_shape,

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "fixedpoint/fixedpoint.h"
#include "public/gemmlowp.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
@ -34,6 +35,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
#include "tensorflow/lite/kernels/internal/round.h"
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
@ -3246,6 +3248,7 @@ inline void PadImageStyle(const tflite::PadParams& op_params,
Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape,
output_data);
}
template <typename T>
inline void StridedSlice(const tflite::StridedSliceParams& op_params,
const RuntimeShape& unextended_input_shape,
@ -3301,8 +3304,9 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& output_shape, T* output_data) {
const RuntimeShape& input_shape,
const RuntimeShape& output_shape,
SequentialTensorWriter<T>* writer) {
const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
// TODO(dkalenichenko): This op only supports 4D tensors or smaller.
TFLITE_DCHECK_LE(op_params.begin_count, 4);
@ -3327,18 +3331,33 @@ inline void Slice(const tflite::SliceParams& op_params,
? ext_shape.Dims(3) - start_d
: start_d + op_params.size[size_count - 1];
T* out_ptr = output_data;
for (int in_b = start_b; in_b < stop_b; ++in_b) {
for (int in_h = start_h; in_h < stop_h; ++in_h) {
for (int in_w = start_w; in_w < stop_w; ++in_w) {
for (int in_d = start_d; in_d < stop_d; ++in_d) {
*out_ptr++ = input_data[Offset(ext_shape, in_b, in_h, in_w, in_d)];
writer->Write(Offset(ext_shape, in_b, in_h, in_w, in_d));
}
}
}
}
}
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& output_shape, T* output_data) {
SequentialTensorWriter<T> writer(input_data, output_data);
return Slice(op_params, input_shape, output_shape, &writer);
}
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const TfLiteTensor* input,
const RuntimeShape& output_shape, TfLiteTensor* output) {
SequentialTensorWriter<T> writer(input, output);
return Slice(op_params, input_shape, output_shape, &writer);
}
template <typename T>
inline void Exp(const T* input_data, const size_t num_elements,
T* output_data) {

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/string_util.h"
namespace tflite {
@ -109,6 +110,48 @@ class VectorOfQuantizedTensors : public VectorOfTensors<uint8> {
std::vector<float> scale_;
};
// Writes randomly accessed values from `input` sequentially into `output`.
template <typename T>
class SequentialTensorWriter {
public:
SequentialTensorWriter(const TfLiteTensor* input, TfLiteTensor* output) {
input_data_ = GetTensorData<T>(input);
output_ptr_ = GetTensorData<T>(output);
}
SequentialTensorWriter(const T* input_data, T* output_data)
: input_data_(input_data), output_ptr_(output_data) {}
void Write(int position) { *output_ptr_++ = input_data_[position]; }
void WriteN(int position, int len) {
memcpy(output_ptr_, &input_data_[position], sizeof(T) * len);
output_ptr_ += len;
}
private:
const T* input_data_;
T* output_ptr_;
};
template <>
class SequentialTensorWriter<string> {
public:
SequentialTensorWriter(const TfLiteTensor* input, TfLiteTensor* output)
: input_(input), output_(output) {}
~SequentialTensorWriter() { buffer_.WriteToTensor(output_, nullptr); }
void Write(int position) { this->WriteN(position, 1); }
void WriteN(int position, int len) {
for (int i = 0; i < len; i++) {
buffer_.AddString(GetString(input_, position + i));
}
}
private:
const TfLiteTensor* input_;
TfLiteTensor* output_;
DynamicBuffer buffer_;
};
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_H_

View File

@ -324,8 +324,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_SELECT, Register_SELECT(),
/* min_version */ 1,
/* max_version */ 2);
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE(), /* min_version */ 1,
/* max_version */ 2);
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE(),
/* min_version */ 1,
/* max_version */ 3);
AddBuiltin(BuiltinOperator_SIN, Register_SIN());
AddBuiltin(BuiltinOperator_COS, Register_COS());
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV());

View File

@ -172,27 +172,25 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// The dimensions in the kernel used to be in reverse-order, and TFLite
// arranged the begins and sizes vectors accordingly. This macro incorporates
// the needed reversing.
#define TF_LITE_SLICE(data_type, kernel_type) \
{ \
TF_LITE_ENSURE_EQ(context, begins.size(), 4); \
TF_LITE_ENSURE_EQ(context, sizes.size(), 4); \
tflite::SliceParams op_params; \
op_params.begin_count = 4; \
op_params.size_count = 4; \
for (int i = 0; i < 4; ++i) { \
op_params.begin[i] = begins[3 - i]; \
op_params.size[i] = sizes[3 - i]; \
} \
\
if (kernel_type == kGenericOptimized) { \
optimized_ops::Slice<data_type>( \
op_params, GetTensorShape(input), GetTensorData<data_type>(input), \
GetTensorShape(output), GetTensorData<data_type>(output)); \
} else { \
reference_ops::Slice<data_type>( \
op_params, GetTensorShape(input), GetTensorData<data_type>(input), \
GetTensorShape(output), GetTensorData<data_type>(output)); \
} \
#define TF_LITE_SLICE(data_type, kernel_type) \
{ \
TF_LITE_ENSURE_EQ(context, begins.size(), 4); \
TF_LITE_ENSURE_EQ(context, sizes.size(), 4); \
tflite::SliceParams op_params; \
op_params.begin_count = 4; \
op_params.size_count = 4; \
for (int i = 0; i < 4; ++i) { \
op_params.begin[i] = begins[3 - i]; \
op_params.size[i] = sizes[3 - i]; \
} \
\
if (kernel_type == kGenericOptimized) { \
optimized_ops::Slice<data_type>(op_params, GetTensorShape(input), input, \
GetTensorShape(output), output); \
} else { \
reference_ops::Slice<data_type>(op_params, GetTensorShape(input), input, \
GetTensorShape(output), output); \
} \
}
switch (input->type) {
@ -214,6 +212,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteBool:
TF_LITE_SLICE(bool, kernel_type);
break;
case kTfLiteString:
TF_LITE_SLICE(string, kernel_type);
break;
default:
context->ReportError(
context, "Type %d is currently not supported by Slice.", input->type);

View File

@ -42,6 +42,9 @@ class SliceOpModel : public SingleOpModel {
void SetInput(std::initializer_list<input_type> data) {
PopulateTensor<input_type>(input_, data);
}
void SetStringInput(std::vector<string> data) {
PopulateStringTensor(input_, data);
}
void SetBegin(std::initializer_list<index_type> data) {
PopulateTensor<index_type>(begin_, data);
}
@ -185,6 +188,24 @@ TEST(SliceOpTest, SliceInt8) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
}
TEST(SliceOpTest, SliceString) {
SliceOpModel<string, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
TensorType_STRING);
m.SetStringInput({"0,0,0,0", "0,0,1,0", "0,0,2,0", //
"0,1,0,0", "0,1,1,0", "0,1,2,0", //
"1,0,0,0", "1,0,1,0", "1,0,2,0", //
"1,1,0,0", "1,1,1,0", "1,1,2,0", //
"2,0,0,0", "2,0,1,0", "2,0,2,0", //
"2,1,0,0", "2,1,1,0", "2,1,2,0"});
m.SetBegin({1, 0, 0, 0});
m.SetSize({2, 1, -1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
EXPECT_THAT(m.GetOutput(),
ElementsAreArray({"1,0,0,0", "1,0,1,0", "1,0,2,0", //
"2,0,0,0", "2,0,1,0", "2,0,2,0"}));
}
} // namespace
} // namespace tflite

View File

@ -3744,7 +3744,7 @@ def make_slice_tests(options):
test_parameters = [
# 4-D
{
"dtype": [tf.float32, tf.int32, tf.int64],
"dtype": [tf.float32, tf.int32, tf.int64, tf.string],
"index_type": [tf.int32, tf.int64],
"input_shape": [[12, 2, 2, 5]],
"begin": [[0, 0, 0, 0], [1, 0, 1, 0]],
@ -3752,7 +3752,7 @@ def make_slice_tests(options):
},
# 2-D
{
"dtype": [tf.float32, tf.int32, tf.int64],
"dtype": [tf.float32, tf.int32, tf.int64, tf.string],
"index_type": [tf.int32, tf.int64],
"input_shape": [[2, 3]],
"begin": [[0, 0], [1, 0]],
@ -3795,7 +3795,7 @@ def make_slice_tests(options):
test_parameters,
build_graph,
build_inputs,
expected_tf_failures=18)
expected_tf_failures=24)
@register_make_test_function()

View File

@ -1709,10 +1709,14 @@ class Slice : public SimpleOperator<SliceOperator> {
int GetVersion(const OperatorSignature& op_signature) const override {
const string& input_name = op_signature.op->inputs[0];
const Array& input_array = op_signature.model->GetArray(input_name);
// Version 2 supports signed int8 input types.
if (input_array.data_type == ArrayDataType::kInt8) {
// Version 2 supports signed int8 input types.
return 2;
}
if (input_array.data_type == ArrayDataType::kString) {
// Version 3 supports string input types.
return 3;
}
return 1;
}
};

View File

@ -817,6 +817,18 @@ TEST_F(OperatorTest, VersioningSpaceToDepthTest) {
TEST_F(OperatorTest, VersioningSliceTest) {
SimpleVersioningTest<SliceOperator>();
// Check that a string input results in a version 3 op.
SliceOperator op;
op.inputs = {"input1"};
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
Model string_model;
Array& string_array = string_model.GetOrCreateArray(op.inputs[0]);
string_array.data_type = ArrayDataType::kString;
OperatorSignature string_signature = {.op = &op, .model = &string_model};
EXPECT_EQ(base_op->GetVersion(string_signature), 3);
}
TEST_F(OperatorTest, VersioningLogisticTest) {