Add string support to slice.
PiperOrigin-RevId: 242045915
This commit is contained in:
parent
51a3acd6a0
commit
0c4d4090c9
@ -189,6 +189,7 @@ cc_library(
|
|||||||
":types",
|
":types",
|
||||||
":reference_base",
|
":reference_base",
|
||||||
":round",
|
":round",
|
||||||
|
":tensor",
|
||||||
":tensor_utils",
|
":tensor_utils",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
"@gemmlowp",
|
"@gemmlowp",
|
||||||
@ -222,6 +223,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":quantization_util",
|
":quantization_util",
|
||||||
":strided_slice_logic",
|
":strided_slice_logic",
|
||||||
|
":tensor",
|
||||||
":tensor_utils",
|
":tensor_utils",
|
||||||
":types",
|
":types",
|
||||||
":legacy_types",
|
":legacy_types",
|
||||||
@ -258,6 +260,7 @@ cc_library(
|
|||||||
":tensor",
|
":tensor",
|
||||||
":types",
|
":types",
|
||||||
"//tensorflow/core/kernels:eigen_spatial_convolutions-inl",
|
"//tensorflow/core/kernels:eigen_spatial_convolutions-inl",
|
||||||
|
"//tensorflow/lite:string_util",
|
||||||
"//tensorflow/lite/c:c_api_internal",
|
"//tensorflow/lite/c:c_api_internal",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
],
|
],
|
||||||
@ -341,6 +344,7 @@ cc_library(
|
|||||||
":quantization_util",
|
":quantization_util",
|
||||||
":round",
|
":round",
|
||||||
":strided_slice_logic",
|
":strided_slice_logic",
|
||||||
|
":tensor",
|
||||||
":types",
|
":types",
|
||||||
"@gemmlowp",
|
"@gemmlowp",
|
||||||
"//tensorflow/lite/c:c_api_internal",
|
"//tensorflow/lite/c:c_api_internal",
|
||||||
@ -376,6 +380,7 @@ cc_library(
|
|||||||
":round",
|
":round",
|
||||||
":strided_slice_logic",
|
":strided_slice_logic",
|
||||||
":legacy_types",
|
":legacy_types",
|
||||||
|
":tensor",
|
||||||
":types",
|
":types",
|
||||||
"@gemmlowp",
|
"@gemmlowp",
|
||||||
"//tensorflow/lite/c:c_api_internal",
|
"//tensorflow/lite/c:c_api_internal",
|
||||||
@ -401,6 +406,7 @@ cc_library(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":types",
|
":types",
|
||||||
|
"//tensorflow/lite:string_util",
|
||||||
"//tensorflow/lite/c:c_api_internal",
|
"//tensorflow/lite/c:c_api_internal",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -414,6 +420,7 @@ cc_library(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":types",
|
":types",
|
||||||
|
"//tensorflow/lite:string_util",
|
||||||
"//tensorflow/lite/c:c_api_internal",
|
"//tensorflow/lite/c:c_api_internal",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -34,12 +34,14 @@ limitations under the License.
|
|||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "fixedpoint/fixedpoint.h"
|
#include "fixedpoint/fixedpoint.h"
|
||||||
#include "public/gemmlowp.h"
|
#include "public/gemmlowp.h"
|
||||||
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
#include "tensorflow/lite/kernels/internal/common.h"
|
#include "tensorflow/lite/kernels/internal/common.h"
|
||||||
#include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h"
|
#include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h"
|
||||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||||
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
|
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
|
||||||
#include "tensorflow/lite/kernels/internal/round.h"
|
#include "tensorflow/lite/kernels/internal/round.h"
|
||||||
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
|
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||||
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
|
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
|
||||||
#include "tensorflow/lite/kernels/internal/types.h"
|
#include "tensorflow/lite/kernels/internal/types.h"
|
||||||
|
|
||||||
@ -5958,8 +5960,9 @@ inline void PadImageStyle(const tflite::PadParams& op_params,
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline void Slice(const tflite::SliceParams& op_params,
|
inline void Slice(const tflite::SliceParams& op_params,
|
||||||
const RuntimeShape& input_shape, const T* input_data,
|
const RuntimeShape& input_shape,
|
||||||
const RuntimeShape& output_shape, T* output_data) {
|
const RuntimeShape& output_shape,
|
||||||
|
SequentialTensorWriter<T>* writer) {
|
||||||
gemmlowp::ScopedProfilingLabel label("Slice");
|
gemmlowp::ScopedProfilingLabel label("Slice");
|
||||||
const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
|
const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
|
||||||
// TODO(dkalenichenko): This op only supports 4D tensors or smaller.
|
// TODO(dkalenichenko): This op only supports 4D tensors or smaller.
|
||||||
@ -5985,20 +5988,32 @@ inline void Slice(const tflite::SliceParams& op_params,
|
|||||||
? ext_shape.Dims(3) - start_d
|
? ext_shape.Dims(3) - start_d
|
||||||
: start_d + op_params.size[size_count - 1];
|
: start_d + op_params.size[size_count - 1];
|
||||||
|
|
||||||
T* out_ptr = output_data;
|
|
||||||
for (int in_b = start_b; in_b < stop_b; ++in_b) {
|
for (int in_b = start_b; in_b < stop_b; ++in_b) {
|
||||||
for (int in_h = start_h; in_h < stop_h; ++in_h) {
|
for (int in_h = start_h; in_h < stop_h; ++in_h) {
|
||||||
for (int in_w = start_w; in_w < stop_w; ++in_w) {
|
for (int in_w = start_w; in_w < stop_w; ++in_w) {
|
||||||
const int len = stop_d - start_d;
|
const int len = stop_d - start_d;
|
||||||
memcpy(out_ptr,
|
writer->WriteN(Offset(ext_shape, in_b, in_h, in_w, start_d), len);
|
||||||
input_data + Offset(ext_shape, in_b, in_h, in_w, start_d),
|
|
||||||
len * sizeof(T));
|
|
||||||
out_ptr += len;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline void Slice(const tflite::SliceParams& op_params,
|
||||||
|
const RuntimeShape& input_shape, const T* input_data,
|
||||||
|
const RuntimeShape& output_shape, T* output_data) {
|
||||||
|
SequentialTensorWriter<T> writer(input_data, output_data);
|
||||||
|
return Slice(op_params, input_shape, output_shape, &writer);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline void Slice(const tflite::SliceParams& op_params,
|
||||||
|
const RuntimeShape& input_shape, const TfLiteTensor* input,
|
||||||
|
const RuntimeShape& output_shape, TfLiteTensor* output) {
|
||||||
|
SequentialTensorWriter<T> writer(input, output);
|
||||||
|
return Slice(op_params, input_shape, output_shape, &writer);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
|
void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
|
||||||
const T* input2_data, const RuntimeShape& output_shape,
|
const T* input2_data, const RuntimeShape& output_shape,
|
||||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "fixedpoint/fixedpoint.h"
|
#include "fixedpoint/fixedpoint.h"
|
||||||
#include "public/gemmlowp.h"
|
#include "public/gemmlowp.h"
|
||||||
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
#include "tensorflow/lite/kernels/internal/common.h"
|
#include "tensorflow/lite/kernels/internal/common.h"
|
||||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||||
#include "tensorflow/lite/kernels/internal/reference/conv.h"
|
#include "tensorflow/lite/kernels/internal/reference/conv.h"
|
||||||
@ -34,6 +35,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
|
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
|
||||||
#include "tensorflow/lite/kernels/internal/round.h"
|
#include "tensorflow/lite/kernels/internal/round.h"
|
||||||
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
|
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||||
#include "tensorflow/lite/kernels/internal/types.h"
|
#include "tensorflow/lite/kernels/internal/types.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
@ -3246,6 +3248,7 @@ inline void PadImageStyle(const tflite::PadParams& op_params,
|
|||||||
Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape,
|
Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape,
|
||||||
output_data);
|
output_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline void StridedSlice(const tflite::StridedSliceParams& op_params,
|
inline void StridedSlice(const tflite::StridedSliceParams& op_params,
|
||||||
const RuntimeShape& unextended_input_shape,
|
const RuntimeShape& unextended_input_shape,
|
||||||
@ -3301,8 +3304,9 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline void Slice(const tflite::SliceParams& op_params,
|
inline void Slice(const tflite::SliceParams& op_params,
|
||||||
const RuntimeShape& input_shape, const T* input_data,
|
const RuntimeShape& input_shape,
|
||||||
const RuntimeShape& output_shape, T* output_data) {
|
const RuntimeShape& output_shape,
|
||||||
|
SequentialTensorWriter<T>* writer) {
|
||||||
const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
|
const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
|
||||||
// TODO(dkalenichenko): This op only supports 4D tensors or smaller.
|
// TODO(dkalenichenko): This op only supports 4D tensors or smaller.
|
||||||
TFLITE_DCHECK_LE(op_params.begin_count, 4);
|
TFLITE_DCHECK_LE(op_params.begin_count, 4);
|
||||||
@ -3327,18 +3331,33 @@ inline void Slice(const tflite::SliceParams& op_params,
|
|||||||
? ext_shape.Dims(3) - start_d
|
? ext_shape.Dims(3) - start_d
|
||||||
: start_d + op_params.size[size_count - 1];
|
: start_d + op_params.size[size_count - 1];
|
||||||
|
|
||||||
T* out_ptr = output_data;
|
|
||||||
for (int in_b = start_b; in_b < stop_b; ++in_b) {
|
for (int in_b = start_b; in_b < stop_b; ++in_b) {
|
||||||
for (int in_h = start_h; in_h < stop_h; ++in_h) {
|
for (int in_h = start_h; in_h < stop_h; ++in_h) {
|
||||||
for (int in_w = start_w; in_w < stop_w; ++in_w) {
|
for (int in_w = start_w; in_w < stop_w; ++in_w) {
|
||||||
for (int in_d = start_d; in_d < stop_d; ++in_d) {
|
for (int in_d = start_d; in_d < stop_d; ++in_d) {
|
||||||
*out_ptr++ = input_data[Offset(ext_shape, in_b, in_h, in_w, in_d)];
|
writer->Write(Offset(ext_shape, in_b, in_h, in_w, in_d));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline void Slice(const tflite::SliceParams& op_params,
|
||||||
|
const RuntimeShape& input_shape, const T* input_data,
|
||||||
|
const RuntimeShape& output_shape, T* output_data) {
|
||||||
|
SequentialTensorWriter<T> writer(input_data, output_data);
|
||||||
|
return Slice(op_params, input_shape, output_shape, &writer);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline void Slice(const tflite::SliceParams& op_params,
|
||||||
|
const RuntimeShape& input_shape, const TfLiteTensor* input,
|
||||||
|
const RuntimeShape& output_shape, TfLiteTensor* output) {
|
||||||
|
SequentialTensorWriter<T> writer(input, output);
|
||||||
|
return Slice(op_params, input_shape, output_shape, &writer);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline void Exp(const T* input_data, const size_t num_elements,
|
inline void Exp(const T* input_data, const size_t num_elements,
|
||||||
T* output_data) {
|
T* output_data) {
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/c/c_api_internal.h"
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||||
#include "tensorflow/lite/kernels/internal/types.h"
|
#include "tensorflow/lite/kernels/internal/types.h"
|
||||||
|
#include "tensorflow/lite/string_util.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
|
|
||||||
@ -109,6 +110,48 @@ class VectorOfQuantizedTensors : public VectorOfTensors<uint8> {
|
|||||||
std::vector<float> scale_;
|
std::vector<float> scale_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Writes randomly accessed values from `input` sequentially into `output`.
|
||||||
|
template <typename T>
|
||||||
|
class SequentialTensorWriter {
|
||||||
|
public:
|
||||||
|
SequentialTensorWriter(const TfLiteTensor* input, TfLiteTensor* output) {
|
||||||
|
input_data_ = GetTensorData<T>(input);
|
||||||
|
output_ptr_ = GetTensorData<T>(output);
|
||||||
|
}
|
||||||
|
SequentialTensorWriter(const T* input_data, T* output_data)
|
||||||
|
: input_data_(input_data), output_ptr_(output_data) {}
|
||||||
|
|
||||||
|
void Write(int position) { *output_ptr_++ = input_data_[position]; }
|
||||||
|
void WriteN(int position, int len) {
|
||||||
|
memcpy(output_ptr_, &input_data_[position], sizeof(T) * len);
|
||||||
|
output_ptr_ += len;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const T* input_data_;
|
||||||
|
T* output_ptr_;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
class SequentialTensorWriter<string> {
|
||||||
|
public:
|
||||||
|
SequentialTensorWriter(const TfLiteTensor* input, TfLiteTensor* output)
|
||||||
|
: input_(input), output_(output) {}
|
||||||
|
~SequentialTensorWriter() { buffer_.WriteToTensor(output_, nullptr); }
|
||||||
|
|
||||||
|
void Write(int position) { this->WriteN(position, 1); }
|
||||||
|
void WriteN(int position, int len) {
|
||||||
|
for (int i = 0; i < len; i++) {
|
||||||
|
buffer_.AddString(GetString(input_, position + i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const TfLiteTensor* input_;
|
||||||
|
TfLiteTensor* output_;
|
||||||
|
DynamicBuffer buffer_;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_H_
|
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_H_
|
||||||
|
@ -324,8 +324,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||||||
AddBuiltin(BuiltinOperator_SELECT, Register_SELECT(),
|
AddBuiltin(BuiltinOperator_SELECT, Register_SELECT(),
|
||||||
/* min_version */ 1,
|
/* min_version */ 1,
|
||||||
/* max_version */ 2);
|
/* max_version */ 2);
|
||||||
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE(), /* min_version */ 1,
|
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE(),
|
||||||
/* max_version */ 2);
|
/* min_version */ 1,
|
||||||
|
/* max_version */ 3);
|
||||||
AddBuiltin(BuiltinOperator_SIN, Register_SIN());
|
AddBuiltin(BuiltinOperator_SIN, Register_SIN());
|
||||||
AddBuiltin(BuiltinOperator_COS, Register_COS());
|
AddBuiltin(BuiltinOperator_COS, Register_COS());
|
||||||
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV());
|
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV());
|
||||||
|
@ -172,27 +172,25 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
// The dimensions in the kernel used to be in reverse-order, and TFLite
|
// The dimensions in the kernel used to be in reverse-order, and TFLite
|
||||||
// arranged the begins and sizes vectors accordingly. This macro incorporates
|
// arranged the begins and sizes vectors accordingly. This macro incorporates
|
||||||
// the needed reversing.
|
// the needed reversing.
|
||||||
#define TF_LITE_SLICE(data_type, kernel_type) \
|
#define TF_LITE_SLICE(data_type, kernel_type) \
|
||||||
{ \
|
{ \
|
||||||
TF_LITE_ENSURE_EQ(context, begins.size(), 4); \
|
TF_LITE_ENSURE_EQ(context, begins.size(), 4); \
|
||||||
TF_LITE_ENSURE_EQ(context, sizes.size(), 4); \
|
TF_LITE_ENSURE_EQ(context, sizes.size(), 4); \
|
||||||
tflite::SliceParams op_params; \
|
tflite::SliceParams op_params; \
|
||||||
op_params.begin_count = 4; \
|
op_params.begin_count = 4; \
|
||||||
op_params.size_count = 4; \
|
op_params.size_count = 4; \
|
||||||
for (int i = 0; i < 4; ++i) { \
|
for (int i = 0; i < 4; ++i) { \
|
||||||
op_params.begin[i] = begins[3 - i]; \
|
op_params.begin[i] = begins[3 - i]; \
|
||||||
op_params.size[i] = sizes[3 - i]; \
|
op_params.size[i] = sizes[3 - i]; \
|
||||||
} \
|
} \
|
||||||
\
|
\
|
||||||
if (kernel_type == kGenericOptimized) { \
|
if (kernel_type == kGenericOptimized) { \
|
||||||
optimized_ops::Slice<data_type>( \
|
optimized_ops::Slice<data_type>(op_params, GetTensorShape(input), input, \
|
||||||
op_params, GetTensorShape(input), GetTensorData<data_type>(input), \
|
GetTensorShape(output), output); \
|
||||||
GetTensorShape(output), GetTensorData<data_type>(output)); \
|
} else { \
|
||||||
} else { \
|
reference_ops::Slice<data_type>(op_params, GetTensorShape(input), input, \
|
||||||
reference_ops::Slice<data_type>( \
|
GetTensorShape(output), output); \
|
||||||
op_params, GetTensorShape(input), GetTensorData<data_type>(input), \
|
} \
|
||||||
GetTensorShape(output), GetTensorData<data_type>(output)); \
|
|
||||||
} \
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (input->type) {
|
switch (input->type) {
|
||||||
@ -214,6 +212,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
case kTfLiteBool:
|
case kTfLiteBool:
|
||||||
TF_LITE_SLICE(bool, kernel_type);
|
TF_LITE_SLICE(bool, kernel_type);
|
||||||
break;
|
break;
|
||||||
|
case kTfLiteString:
|
||||||
|
TF_LITE_SLICE(string, kernel_type);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
context->ReportError(
|
context->ReportError(
|
||||||
context, "Type %d is currently not supported by Slice.", input->type);
|
context, "Type %d is currently not supported by Slice.", input->type);
|
||||||
|
@ -42,6 +42,9 @@ class SliceOpModel : public SingleOpModel {
|
|||||||
void SetInput(std::initializer_list<input_type> data) {
|
void SetInput(std::initializer_list<input_type> data) {
|
||||||
PopulateTensor<input_type>(input_, data);
|
PopulateTensor<input_type>(input_, data);
|
||||||
}
|
}
|
||||||
|
void SetStringInput(std::vector<string> data) {
|
||||||
|
PopulateStringTensor(input_, data);
|
||||||
|
}
|
||||||
void SetBegin(std::initializer_list<index_type> data) {
|
void SetBegin(std::initializer_list<index_type> data) {
|
||||||
PopulateTensor<index_type>(begin_, data);
|
PopulateTensor<index_type>(begin_, data);
|
||||||
}
|
}
|
||||||
@ -185,6 +188,24 @@ TEST(SliceOpTest, SliceInt8) {
|
|||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(SliceOpTest, SliceString) {
|
||||||
|
SliceOpModel<string, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
|
||||||
|
TensorType_STRING);
|
||||||
|
m.SetStringInput({"0,0,0,0", "0,0,1,0", "0,0,2,0", //
|
||||||
|
"0,1,0,0", "0,1,1,0", "0,1,2,0", //
|
||||||
|
"1,0,0,0", "1,0,1,0", "1,0,2,0", //
|
||||||
|
"1,1,0,0", "1,1,1,0", "1,1,2,0", //
|
||||||
|
"2,0,0,0", "2,0,1,0", "2,0,2,0", //
|
||||||
|
"2,1,0,0", "2,1,1,0", "2,1,2,0"});
|
||||||
|
m.SetBegin({1, 0, 0, 0});
|
||||||
|
m.SetSize({2, 1, -1, 1});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
|
||||||
|
EXPECT_THAT(m.GetOutput(),
|
||||||
|
ElementsAreArray({"1,0,0,0", "1,0,1,0", "1,0,2,0", //
|
||||||
|
"2,0,0,0", "2,0,1,0", "2,0,2,0"}));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
|
@ -3744,7 +3744,7 @@ def make_slice_tests(options):
|
|||||||
test_parameters = [
|
test_parameters = [
|
||||||
# 4-D
|
# 4-D
|
||||||
{
|
{
|
||||||
"dtype": [tf.float32, tf.int32, tf.int64],
|
"dtype": [tf.float32, tf.int32, tf.int64, tf.string],
|
||||||
"index_type": [tf.int32, tf.int64],
|
"index_type": [tf.int32, tf.int64],
|
||||||
"input_shape": [[12, 2, 2, 5]],
|
"input_shape": [[12, 2, 2, 5]],
|
||||||
"begin": [[0, 0, 0, 0], [1, 0, 1, 0]],
|
"begin": [[0, 0, 0, 0], [1, 0, 1, 0]],
|
||||||
@ -3752,7 +3752,7 @@ def make_slice_tests(options):
|
|||||||
},
|
},
|
||||||
# 2-D
|
# 2-D
|
||||||
{
|
{
|
||||||
"dtype": [tf.float32, tf.int32, tf.int64],
|
"dtype": [tf.float32, tf.int32, tf.int64, tf.string],
|
||||||
"index_type": [tf.int32, tf.int64],
|
"index_type": [tf.int32, tf.int64],
|
||||||
"input_shape": [[2, 3]],
|
"input_shape": [[2, 3]],
|
||||||
"begin": [[0, 0], [1, 0]],
|
"begin": [[0, 0], [1, 0]],
|
||||||
@ -3795,7 +3795,7 @@ def make_slice_tests(options):
|
|||||||
test_parameters,
|
test_parameters,
|
||||||
build_graph,
|
build_graph,
|
||||||
build_inputs,
|
build_inputs,
|
||||||
expected_tf_failures=18)
|
expected_tf_failures=24)
|
||||||
|
|
||||||
|
|
||||||
@register_make_test_function()
|
@register_make_test_function()
|
||||||
|
@ -1709,10 +1709,14 @@ class Slice : public SimpleOperator<SliceOperator> {
|
|||||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||||
const string& input_name = op_signature.op->inputs[0];
|
const string& input_name = op_signature.op->inputs[0];
|
||||||
const Array& input_array = op_signature.model->GetArray(input_name);
|
const Array& input_array = op_signature.model->GetArray(input_name);
|
||||||
// Version 2 supports signed int8 input types.
|
|
||||||
if (input_array.data_type == ArrayDataType::kInt8) {
|
if (input_array.data_type == ArrayDataType::kInt8) {
|
||||||
|
// Version 2 supports signed int8 input types.
|
||||||
return 2;
|
return 2;
|
||||||
}
|
}
|
||||||
|
if (input_array.data_type == ArrayDataType::kString) {
|
||||||
|
// Version 3 supports string input types.
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -817,6 +817,18 @@ TEST_F(OperatorTest, VersioningSpaceToDepthTest) {
|
|||||||
|
|
||||||
TEST_F(OperatorTest, VersioningSliceTest) {
|
TEST_F(OperatorTest, VersioningSliceTest) {
|
||||||
SimpleVersioningTest<SliceOperator>();
|
SimpleVersioningTest<SliceOperator>();
|
||||||
|
|
||||||
|
// Check that a string input results in a version 3 op.
|
||||||
|
SliceOperator op;
|
||||||
|
op.inputs = {"input1"};
|
||||||
|
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
||||||
|
const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
|
||||||
|
|
||||||
|
Model string_model;
|
||||||
|
Array& string_array = string_model.GetOrCreateArray(op.inputs[0]);
|
||||||
|
string_array.data_type = ArrayDataType::kString;
|
||||||
|
OperatorSignature string_signature = {.op = &op, .model = &string_model};
|
||||||
|
EXPECT_EQ(base_op->GetVersion(string_signature), 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OperatorTest, VersioningLogisticTest) {
|
TEST_F(OperatorTest, VersioningLogisticTest) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user