Add string support to slice.

PiperOrigin-RevId: 242045915
This commit is contained in:
A. Unique TensorFlower 2019-04-04 17:55:17 -07:00 committed by TensorFlower Gardener
parent 51a3acd6a0
commit 0c4d4090c9
10 changed files with 161 additions and 38 deletions

View File

@ -189,6 +189,7 @@ cc_library(
":types", ":types",
":reference_base", ":reference_base",
":round", ":round",
":tensor",
":tensor_utils", ":tensor_utils",
"//third_party/eigen3", "//third_party/eigen3",
"@gemmlowp", "@gemmlowp",
@ -222,6 +223,7 @@ cc_library(
deps = [ deps = [
":quantization_util", ":quantization_util",
":strided_slice_logic", ":strided_slice_logic",
":tensor",
":tensor_utils", ":tensor_utils",
":types", ":types",
":legacy_types", ":legacy_types",
@ -258,6 +260,7 @@ cc_library(
":tensor", ":tensor",
":types", ":types",
"//tensorflow/core/kernels:eigen_spatial_convolutions-inl", "//tensorflow/core/kernels:eigen_spatial_convolutions-inl",
"//tensorflow/lite:string_util",
"//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/c:c_api_internal",
"//third_party/eigen3", "//third_party/eigen3",
], ],
@ -341,6 +344,7 @@ cc_library(
":quantization_util", ":quantization_util",
":round", ":round",
":strided_slice_logic", ":strided_slice_logic",
":tensor",
":types", ":types",
"@gemmlowp", "@gemmlowp",
"//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/c:c_api_internal",
@ -376,6 +380,7 @@ cc_library(
":round", ":round",
":strided_slice_logic", ":strided_slice_logic",
":legacy_types", ":legacy_types",
":tensor",
":types", ":types",
"@gemmlowp", "@gemmlowp",
"//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/c:c_api_internal",
@ -401,6 +406,7 @@ cc_library(
], ],
deps = [ deps = [
":types", ":types",
"//tensorflow/lite:string_util",
"//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/c:c_api_internal",
], ],
) )
@ -414,6 +420,7 @@ cc_library(
], ],
deps = [ deps = [
":types", ":types",
"//tensorflow/lite:string_util",
"//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/c:c_api_internal",
], ],
) )

View File

@ -34,12 +34,14 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "fixedpoint/fixedpoint.h" #include "fixedpoint/fixedpoint.h"
#include "public/gemmlowp.h" #include "public/gemmlowp.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h" #include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/round.h" #include "tensorflow/lite/kernels/internal/round.h"
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h" #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h" #include "tensorflow/lite/kernels/internal/types.h"
@ -5958,8 +5960,9 @@ inline void PadImageStyle(const tflite::PadParams& op_params,
template <typename T> template <typename T>
inline void Slice(const tflite::SliceParams& op_params, inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& input_shape,
const RuntimeShape& output_shape, T* output_data) { const RuntimeShape& output_shape,
SequentialTensorWriter<T>* writer) {
gemmlowp::ScopedProfilingLabel label("Slice"); gemmlowp::ScopedProfilingLabel label("Slice");
const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape); const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
// TODO(dkalenichenko): This op only supports 4D tensors or smaller. // TODO(dkalenichenko): This op only supports 4D tensors or smaller.
@ -5985,20 +5988,32 @@ inline void Slice(const tflite::SliceParams& op_params,
? ext_shape.Dims(3) - start_d ? ext_shape.Dims(3) - start_d
: start_d + op_params.size[size_count - 1]; : start_d + op_params.size[size_count - 1];
T* out_ptr = output_data;
for (int in_b = start_b; in_b < stop_b; ++in_b) { for (int in_b = start_b; in_b < stop_b; ++in_b) {
for (int in_h = start_h; in_h < stop_h; ++in_h) { for (int in_h = start_h; in_h < stop_h; ++in_h) {
for (int in_w = start_w; in_w < stop_w; ++in_w) { for (int in_w = start_w; in_w < stop_w; ++in_w) {
const int len = stop_d - start_d; const int len = stop_d - start_d;
memcpy(out_ptr, writer->WriteN(Offset(ext_shape, in_b, in_h, in_w, start_d), len);
input_data + Offset(ext_shape, in_b, in_h, in_w, start_d),
len * sizeof(T));
out_ptr += len;
} }
} }
} }
} }
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& output_shape, T* output_data) {
SequentialTensorWriter<T> writer(input_data, output_data);
return Slice(op_params, input_shape, output_shape, &writer);
}
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const TfLiteTensor* input,
const RuntimeShape& output_shape, TfLiteTensor* output) {
SequentialTensorWriter<T> writer(input, output);
return Slice(op_params, input_shape, output_shape, &writer);
}
template <typename T> template <typename T>
void Minimum(const RuntimeShape& input1_shape, const T* input1_data, void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
const T* input2_data, const RuntimeShape& output_shape, const T* input2_data, const RuntimeShape& output_shape,

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "fixedpoint/fixedpoint.h" #include "fixedpoint/fixedpoint.h"
#include "public/gemmlowp.h" #include "public/gemmlowp.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h" #include "tensorflow/lite/kernels/internal/reference/conv.h"
@ -34,6 +35,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/reference/softmax.h" #include "tensorflow/lite/kernels/internal/reference/softmax.h"
#include "tensorflow/lite/kernels/internal/round.h" #include "tensorflow/lite/kernels/internal/round.h"
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h" #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/types.h" #include "tensorflow/lite/kernels/internal/types.h"
namespace tflite { namespace tflite {
@ -3246,6 +3248,7 @@ inline void PadImageStyle(const tflite::PadParams& op_params,
Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape, Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape,
output_data); output_data);
} }
template <typename T> template <typename T>
inline void StridedSlice(const tflite::StridedSliceParams& op_params, inline void StridedSlice(const tflite::StridedSliceParams& op_params,
const RuntimeShape& unextended_input_shape, const RuntimeShape& unextended_input_shape,
@ -3301,8 +3304,9 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
template <typename T> template <typename T>
inline void Slice(const tflite::SliceParams& op_params, inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& input_shape,
const RuntimeShape& output_shape, T* output_data) { const RuntimeShape& output_shape,
SequentialTensorWriter<T>* writer) {
const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape); const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
// TODO(dkalenichenko): This op only supports 4D tensors or smaller. // TODO(dkalenichenko): This op only supports 4D tensors or smaller.
TFLITE_DCHECK_LE(op_params.begin_count, 4); TFLITE_DCHECK_LE(op_params.begin_count, 4);
@ -3327,18 +3331,33 @@ inline void Slice(const tflite::SliceParams& op_params,
? ext_shape.Dims(3) - start_d ? ext_shape.Dims(3) - start_d
: start_d + op_params.size[size_count - 1]; : start_d + op_params.size[size_count - 1];
T* out_ptr = output_data;
for (int in_b = start_b; in_b < stop_b; ++in_b) { for (int in_b = start_b; in_b < stop_b; ++in_b) {
for (int in_h = start_h; in_h < stop_h; ++in_h) { for (int in_h = start_h; in_h < stop_h; ++in_h) {
for (int in_w = start_w; in_w < stop_w; ++in_w) { for (int in_w = start_w; in_w < stop_w; ++in_w) {
for (int in_d = start_d; in_d < stop_d; ++in_d) { for (int in_d = start_d; in_d < stop_d; ++in_d) {
*out_ptr++ = input_data[Offset(ext_shape, in_b, in_h, in_w, in_d)]; writer->Write(Offset(ext_shape, in_b, in_h, in_w, in_d));
} }
} }
} }
} }
} }
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& output_shape, T* output_data) {
SequentialTensorWriter<T> writer(input_data, output_data);
return Slice(op_params, input_shape, output_shape, &writer);
}
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const TfLiteTensor* input,
const RuntimeShape& output_shape, TfLiteTensor* output) {
SequentialTensorWriter<T> writer(input, output);
return Slice(op_params, input_shape, output_shape, &writer);
}
template <typename T> template <typename T>
inline void Exp(const T* input_data, const size_t num_elements, inline void Exp(const T* input_data, const size_t num_elements,
T* output_data) { T* output_data) {

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h" #include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/string_util.h"
namespace tflite { namespace tflite {
@ -109,6 +110,48 @@ class VectorOfQuantizedTensors : public VectorOfTensors<uint8> {
std::vector<float> scale_; std::vector<float> scale_;
}; };
// Writes randomly accessed values from `input` sequentially into `output`.
template <typename T>
class SequentialTensorWriter {
public:
SequentialTensorWriter(const TfLiteTensor* input, TfLiteTensor* output) {
input_data_ = GetTensorData<T>(input);
output_ptr_ = GetTensorData<T>(output);
}
SequentialTensorWriter(const T* input_data, T* output_data)
: input_data_(input_data), output_ptr_(output_data) {}
void Write(int position) { *output_ptr_++ = input_data_[position]; }
void WriteN(int position, int len) {
memcpy(output_ptr_, &input_data_[position], sizeof(T) * len);
output_ptr_ += len;
}
private:
const T* input_data_;
T* output_ptr_;
};
template <>
class SequentialTensorWriter<string> {
public:
SequentialTensorWriter(const TfLiteTensor* input, TfLiteTensor* output)
: input_(input), output_(output) {}
~SequentialTensorWriter() { buffer_.WriteToTensor(output_, nullptr); }
void Write(int position) { this->WriteN(position, 1); }
void WriteN(int position, int len) {
for (int i = 0; i < len; i++) {
buffer_.AddString(GetString(input_, position + i));
}
}
private:
const TfLiteTensor* input_;
TfLiteTensor* output_;
DynamicBuffer buffer_;
};
} // namespace tflite } // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_H_ #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_H_

View File

@ -324,8 +324,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_SELECT, Register_SELECT(), AddBuiltin(BuiltinOperator_SELECT, Register_SELECT(),
/* min_version */ 1, /* min_version */ 1,
/* max_version */ 2); /* max_version */ 2);
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE(), /* min_version */ 1, AddBuiltin(BuiltinOperator_SLICE, Register_SLICE(),
/* max_version */ 2); /* min_version */ 1,
/* max_version */ 3);
AddBuiltin(BuiltinOperator_SIN, Register_SIN()); AddBuiltin(BuiltinOperator_SIN, Register_SIN());
AddBuiltin(BuiltinOperator_COS, Register_COS()); AddBuiltin(BuiltinOperator_COS, Register_COS());
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV()); AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV());

View File

@ -172,27 +172,25 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// The dimensions in the kernel used to be in reverse-order, and TFLite // The dimensions in the kernel used to be in reverse-order, and TFLite
// arranged the begins and sizes vectors accordingly. This macro incorporates // arranged the begins and sizes vectors accordingly. This macro incorporates
// the needed reversing. // the needed reversing.
#define TF_LITE_SLICE(data_type, kernel_type) \ #define TF_LITE_SLICE(data_type, kernel_type) \
{ \ { \
TF_LITE_ENSURE_EQ(context, begins.size(), 4); \ TF_LITE_ENSURE_EQ(context, begins.size(), 4); \
TF_LITE_ENSURE_EQ(context, sizes.size(), 4); \ TF_LITE_ENSURE_EQ(context, sizes.size(), 4); \
tflite::SliceParams op_params; \ tflite::SliceParams op_params; \
op_params.begin_count = 4; \ op_params.begin_count = 4; \
op_params.size_count = 4; \ op_params.size_count = 4; \
for (int i = 0; i < 4; ++i) { \ for (int i = 0; i < 4; ++i) { \
op_params.begin[i] = begins[3 - i]; \ op_params.begin[i] = begins[3 - i]; \
op_params.size[i] = sizes[3 - i]; \ op_params.size[i] = sizes[3 - i]; \
} \ } \
\ \
if (kernel_type == kGenericOptimized) { \ if (kernel_type == kGenericOptimized) { \
optimized_ops::Slice<data_type>( \ optimized_ops::Slice<data_type>(op_params, GetTensorShape(input), input, \
op_params, GetTensorShape(input), GetTensorData<data_type>(input), \ GetTensorShape(output), output); \
GetTensorShape(output), GetTensorData<data_type>(output)); \ } else { \
} else { \ reference_ops::Slice<data_type>(op_params, GetTensorShape(input), input, \
reference_ops::Slice<data_type>( \ GetTensorShape(output), output); \
op_params, GetTensorShape(input), GetTensorData<data_type>(input), \ } \
GetTensorShape(output), GetTensorData<data_type>(output)); \
} \
} }
switch (input->type) { switch (input->type) {
@ -214,6 +212,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteBool: case kTfLiteBool:
TF_LITE_SLICE(bool, kernel_type); TF_LITE_SLICE(bool, kernel_type);
break; break;
case kTfLiteString:
TF_LITE_SLICE(string, kernel_type);
break;
default: default:
context->ReportError( context->ReportError(
context, "Type %d is currently not supported by Slice.", input->type); context, "Type %d is currently not supported by Slice.", input->type);

View File

@ -42,6 +42,9 @@ class SliceOpModel : public SingleOpModel {
void SetInput(std::initializer_list<input_type> data) { void SetInput(std::initializer_list<input_type> data) {
PopulateTensor<input_type>(input_, data); PopulateTensor<input_type>(input_, data);
} }
void SetStringInput(std::vector<string> data) {
PopulateStringTensor(input_, data);
}
void SetBegin(std::initializer_list<index_type> data) { void SetBegin(std::initializer_list<index_type> data) {
PopulateTensor<index_type>(begin_, data); PopulateTensor<index_type>(begin_, data);
} }
@ -185,6 +188,24 @@ TEST(SliceOpTest, SliceInt8) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5})); EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
} }
TEST(SliceOpTest, SliceString) {
SliceOpModel<string, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
TensorType_STRING);
m.SetStringInput({"0,0,0,0", "0,0,1,0", "0,0,2,0", //
"0,1,0,0", "0,1,1,0", "0,1,2,0", //
"1,0,0,0", "1,0,1,0", "1,0,2,0", //
"1,1,0,0", "1,1,1,0", "1,1,2,0", //
"2,0,0,0", "2,0,1,0", "2,0,2,0", //
"2,1,0,0", "2,1,1,0", "2,1,2,0"});
m.SetBegin({1, 0, 0, 0});
m.SetSize({2, 1, -1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
EXPECT_THAT(m.GetOutput(),
ElementsAreArray({"1,0,0,0", "1,0,1,0", "1,0,2,0", //
"2,0,0,0", "2,0,1,0", "2,0,2,0"}));
}
} // namespace } // namespace
} // namespace tflite } // namespace tflite

View File

@ -3744,7 +3744,7 @@ def make_slice_tests(options):
test_parameters = [ test_parameters = [
# 4-D # 4-D
{ {
"dtype": [tf.float32, tf.int32, tf.int64], "dtype": [tf.float32, tf.int32, tf.int64, tf.string],
"index_type": [tf.int32, tf.int64], "index_type": [tf.int32, tf.int64],
"input_shape": [[12, 2, 2, 5]], "input_shape": [[12, 2, 2, 5]],
"begin": [[0, 0, 0, 0], [1, 0, 1, 0]], "begin": [[0, 0, 0, 0], [1, 0, 1, 0]],
@ -3752,7 +3752,7 @@ def make_slice_tests(options):
}, },
# 2-D # 2-D
{ {
"dtype": [tf.float32, tf.int32, tf.int64], "dtype": [tf.float32, tf.int32, tf.int64, tf.string],
"index_type": [tf.int32, tf.int64], "index_type": [tf.int32, tf.int64],
"input_shape": [[2, 3]], "input_shape": [[2, 3]],
"begin": [[0, 0], [1, 0]], "begin": [[0, 0], [1, 0]],
@ -3795,7 +3795,7 @@ def make_slice_tests(options):
test_parameters, test_parameters,
build_graph, build_graph,
build_inputs, build_inputs,
expected_tf_failures=18) expected_tf_failures=24)
@register_make_test_function() @register_make_test_function()

View File

@ -1709,10 +1709,14 @@ class Slice : public SimpleOperator<SliceOperator> {
int GetVersion(const OperatorSignature& op_signature) const override { int GetVersion(const OperatorSignature& op_signature) const override {
const string& input_name = op_signature.op->inputs[0]; const string& input_name = op_signature.op->inputs[0];
const Array& input_array = op_signature.model->GetArray(input_name); const Array& input_array = op_signature.model->GetArray(input_name);
// Version 2 supports signed int8 input types.
if (input_array.data_type == ArrayDataType::kInt8) { if (input_array.data_type == ArrayDataType::kInt8) {
// Version 2 supports signed int8 input types.
return 2; return 2;
} }
if (input_array.data_type == ArrayDataType::kString) {
// Version 3 supports string input types.
return 3;
}
return 1; return 1;
} }
}; };

View File

@ -817,6 +817,18 @@ TEST_F(OperatorTest, VersioningSpaceToDepthTest) {
TEST_F(OperatorTest, VersioningSliceTest) { TEST_F(OperatorTest, VersioningSliceTest) {
SimpleVersioningTest<SliceOperator>(); SimpleVersioningTest<SliceOperator>();
// Check that a string input results in a version 3 op.
SliceOperator op;
op.inputs = {"input1"};
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
Model string_model;
Array& string_array = string_model.GetOrCreateArray(op.inputs[0]);
string_array.data_type = ArrayDataType::kString;
OperatorSignature string_signature = {.op = &op, .model = &string_model};
EXPECT_EQ(base_op->GetVersion(string_signature), 3);
} }
TEST_F(OperatorTest, VersioningLogisticTest) { TEST_F(OperatorTest, VersioningLogisticTest) {