[tf.lite] Add 5D support to slice

PiperOrigin-RevId: 349439300
Change-Id: I74c91acd4d230d30f6518469949aaa94dcd9045e
This commit is contained in:
Jared Duke 2020-12-29 10:19:23 -08:00 committed by TensorFlower Gardener
parent d00879d9ae
commit 7f20831cf7
13 changed files with 114 additions and 68 deletions

View File

@ -68,6 +68,7 @@
* Both symmetric and asymmetric quantized input tensor are supported.
* Add `RFFT2D` as builtin op. (`RFFT2D` also supports `RFFTD`.) Currently
only supports float32 input.
* Add 5D support to `SLICE` op.
* TFLite Supports SingatureDef:
* TFLiteConverter exports models with SignatureDef
* Interpreter supports getting a list of signatures and getting callable

View File

@ -2085,7 +2085,7 @@ def TFL_SliceOp : TFL_Op<"slice", [
TFL_TCresVTEtIsSameAsOp<0, 0>>,
NoSideEffect,
SameOperandsAndResultsScale,
TFL_OperandHasRankAtMost<0, 4>,
TFL_OperandHasRankAtMost<0, 5>,
TFL_OperandHasRankAtMost<1, 1>,
TFL_OperandHasRankAtMost<2, 1>]> {
let summary = "Return a slice from 'input'.";

View File

@ -5402,36 +5402,32 @@ inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& output_shape,
SequentialTensorWriter<T>* writer) {
ruy::profiler::ScopeLabel label("Slice");
const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
// TODO(dkalenichenko): This op only supports 4D tensors or smaller.
TFLITE_DCHECK_LE(op_params.begin_count, 4);
TFLITE_DCHECK_LE(op_params.size_count, 4);
const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(5, input_shape);
TFLITE_DCHECK_LE(op_params.begin_count, 5);
TFLITE_DCHECK_LE(op_params.size_count, 5);
const int begin_count = op_params.begin_count;
const int size_count = op_params.size_count;
// We front-pad the begin and size vectors.
const int start_b = 4 - begin_count > 0 ? 0 : op_params.begin[0];
const int stop_b = (4 - size_count > 0 || op_params.size[0] == -1)
? ext_shape.Dims(0)
: start_b + op_params.size[0];
const int start_h = begin_count < 3 ? 0 : op_params.begin[begin_count - 3];
const int stop_h = (size_count < 3 || op_params.size[size_count - 3] == -1)
? ext_shape.Dims(1)
: start_h + op_params.size[size_count - 3];
const int start_w = begin_count < 2 ? 0 : op_params.begin[begin_count - 2];
const int stop_w = (size_count < 2 || op_params.size[size_count - 2] == -1)
? ext_shape.Dims(2)
: start_w + op_params.size[size_count - 2];
const int start_d = begin_count < 1 ? 0 : op_params.begin[begin_count - 1];
const int stop_d = (size_count < 1 || op_params.size[size_count - 1] == -1)
? ext_shape.Dims(3)
: start_d + op_params.size[size_count - 1];
std::array<int, 5> start;
std::array<int, 5> stop;
for (int i = 0; i < 5; ++i) {
int padded_i = 5 - i;
start[i] =
begin_count < padded_i ? 0 : op_params.begin[begin_count - padded_i];
stop[i] =
(size_count < padded_i || op_params.size[size_count - padded_i] == -1)
? ext_shape.Dims(i)
: start[i] + op_params.size[size_count - padded_i];
}
for (int in_b = start_b; in_b < stop_b; ++in_b) {
for (int in_h = start_h; in_h < stop_h; ++in_h) {
for (int in_w = start_w; in_w < stop_w; ++in_w) {
const int len = stop_d - start_d;
if (len > 0)
writer->WriteN(Offset(ext_shape, in_b, in_h, in_w, start_d), len);
for (int i0 = start[0]; i0 < stop[0]; ++i0) {
for (int i1 = start[1]; i1 < stop[1]; ++i1) {
for (int i2 = start[2]; i2 < stop[2]; ++i2) {
for (int i3 = start[3]; i3 < stop[3]; ++i3) {
const int len = stop[4] - start[4];
if (len > 0)
writer->WriteN(Offset(ext_shape, i0, i1, i2, i3, start[4]), len);
}
}
}
}

View File

@ -1730,35 +1730,31 @@ inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape,
const RuntimeShape& output_shape,
SequentialTensorWriter<T>* writer) {
const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
// TODO(b/174275841): This op only supports 4D tensors or smaller.
TFLITE_DCHECK_LE(op_params.begin_count, 4);
TFLITE_DCHECK_LE(op_params.size_count, 4);
const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(5, input_shape);
TFLITE_DCHECK_LE(op_params.begin_count, 5);
TFLITE_DCHECK_LE(op_params.size_count, 5);
const int begin_count = op_params.begin_count;
const int size_count = op_params.size_count;
// We front-pad the begin and size vectors.
const int start_b = 4 - begin_count > 0 ? 0 : op_params.begin[0];
const int stop_b = (4 - size_count > 0 || op_params.size[0] == -1)
? ext_shape.Dims(0)
: start_b + op_params.size[0];
const int start_h = begin_count < 3 ? 0 : op_params.begin[begin_count - 3];
const int stop_h = (size_count < 3 || op_params.size[size_count - 3] == -1)
? ext_shape.Dims(1)
: start_h + op_params.size[size_count - 3];
const int start_w = begin_count < 2 ? 0 : op_params.begin[begin_count - 2];
const int stop_w = (size_count < 2 || op_params.size[size_count - 2] == -1)
? ext_shape.Dims(2)
: start_w + op_params.size[size_count - 2];
const int start_d = begin_count < 1 ? 0 : op_params.begin[begin_count - 1];
const int stop_d = (size_count < 1 || op_params.size[size_count - 1] == -1)
? ext_shape.Dims(3)
: start_d + op_params.size[size_count - 1];
std::array<int, 5> start;
std::array<int, 5> stop;
for (int i = 0; i < 5; ++i) {
int padded_i = 5 - i;
start[i] =
begin_count < padded_i ? 0 : op_params.begin[begin_count - padded_i];
stop[i] =
(size_count < padded_i || op_params.size[size_count - padded_i] == -1)
? ext_shape.Dims(i)
: start[i] + op_params.size[size_count - padded_i];
}
for (int in_b = start_b; in_b < stop_b; ++in_b) {
for (int in_h = start_h; in_h < stop_h; ++in_h) {
for (int in_w = start_w; in_w < stop_w; ++in_w) {
for (int in_d = start_d; in_d < stop_d; ++in_d) {
writer->Write(Offset(ext_shape, in_b, in_h, in_w, in_d));
for (int i0 = start[0]; i0 < stop[0]; ++i0) {
for (int i1 = start[1]; i1 < stop[1]; ++i1) {
for (int i2 = start[2]; i2 < stop[2]; ++i2) {
for (int i3 = start[3]; i3 < stop[3]; ++i3) {
for (int i4 = start[4]; i4 < stop[4]; ++i4) {
writer->Write(Offset(ext_shape, i0, i1, i2, i3, i4));
}
}
}
}

View File

@ -392,6 +392,20 @@ inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
}
inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3,
int i4) {
TFLITE_DCHECK_EQ(shape.DimensionsCount(), 5);
const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo5D());
TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]);
TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]);
TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]);
TFLITE_DCHECK(i3 >= 0 && i3 < dims_data[3]);
TFLITE_DCHECK(i4 >= 0 && i4 < dims_data[4]);
return (((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3) *
dims_data[4] +
i4;
}
inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
@ -1025,9 +1039,9 @@ struct ResizeNearestNeighborParams {
struct SliceParams {
int8_t begin_count;
int32_t begin[4];
int32_t begin[5];
int8_t size_count;
int32_t size[4];
int32_t size[5];
};
struct SoftmaxParams {

View File

@ -207,7 +207,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_SELECT_V2, Register_SELECT_V2());
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE(),
/* min_version = */ 1,
/* max_version = */ 4);
/* max_version = */ 5);
AddBuiltin(BuiltinOperator_SIN, Register_SIN());
AddBuiltin(BuiltinOperator_COS, Register_COS());
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV(),

View File

@ -366,7 +366,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
AddBuiltin(BuiltinOperator_SELECT_V2, Register_SELECT_V2());
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE_REF(),
/* min_version = */ 1,
/* max_version = */ 4);
/* max_version = */ 5);
AddBuiltin(BuiltinOperator_SIN, Register_SIN());
AddBuiltin(BuiltinOperator_COS, Register_COS());
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSECONV_REF(),

View File

@ -44,9 +44,9 @@ constexpr int kBeginTensor = 1;
constexpr int kSizeTensor = 2;
constexpr int kOutputTensor = 0;
// This Op only supports 1-4D cases and since we use the optimized ops 4D
// implementation, the 1-3D tensors are mapped to 4D.
const int kMaxDim = 4;
// This Op only supports 1-5D cases and since we use the optimized ops 5D
// implementation, the 1-4D tensors are mapped to 5D.
const int kMaxDim = 5;
template <typename T>
TfLiteStatus CalculateOutputShapeVector(TfLiteContext* context,
@ -133,7 +133,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1);
TF_LITE_ENSURE_EQ(context, NumElements(begin), NumElements(size));
TF_LITE_ENSURE_MSG(context, NumDimensions(input) <= kMaxDim,
"Slice op only supports 1D-4D input arrays.");
"Slice op only supports 1D-5D input arrays.");
// Postpone allocation of output if any of the indexing tensors is not
// constant
@ -184,20 +184,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteError;
}
// The original Slice op implementation only accepted 4-D sizes. That
// constraint is, for the present, maintained here.
// The Slice op implementation only accepts 5-D sizes. That constraint is, for
// the present, maintained here.
//
// The dimensions in the kernel used to be in reverse-order, and TFLite
// arranged the begins and sizes vectors accordingly. This macro incorporates
// the needed reversing.
#define TF_LITE_SLICE(data_type, kernel_type) \
{ \
TF_LITE_ENSURE_EQ(context, begins.size(), 4); \
TF_LITE_ENSURE_EQ(context, sizes.size(), 4); \
TF_LITE_ENSURE_EQ(context, begins.size(), kMaxDim); \
TF_LITE_ENSURE_EQ(context, sizes.size(), kMaxDim); \
tflite::SliceParams op_params; \
op_params.begin_count = 4; \
op_params.size_count = 4; \
for (int i = 0; i < 4; ++i) { \
op_params.begin_count = kMaxDim; \
op_params.size_count = kMaxDim; \
for (int i = 0; i < kMaxDim; ++i) { \
op_params.begin[i] = begins[i]; \
op_params.size[i] = sizes[i]; \
} \

View File

@ -114,6 +114,16 @@ TEST_P(SliceOpTest, In3D) {
ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
}
TEST_P(SliceOpTest, In5D) {
SliceOpModel<float, int32_t> m({5, 1, 1, 1, 1}, {5}, {1, 0, 0, 0, 0}, {5},
{3, 1, 1, 1, 1}, TensorType_INT32,
TensorType_FLOAT32, GetParam());
m.SetInput({1, 2, 3, 4, 5});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 1, 1, 1}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4}));
}
TEST_P(SliceOpTest, InputFloat) {
SliceOpModel<float, int32_t> m({4, 1, 1, 1}, {4}, {1, 0, 0, 0}, {4},
{3, 1, 1, 1}, TensorType_INT32,

View File

@ -41,6 +41,16 @@ def make_slice_tests(options):
"constant_indices": [False],
"fully_quantize": [False],
},
# 5-D
{
"dtype": [tf.float32],
"index_type": [tf.int32],
"input_shape": [[6, 2, 2, 2, 5]],
"begin": [[0, 0, 0, 0, 0], [0, 1, 0, 1, 0]],
"size": [[4, 2, 2, 2, 3], [5, 2, 1, 1, 5]],
"constant_indices": [False],
"fully_quantize": [False],
},
# 2-D
{
"dtype": [tf.float32, tf.int32, tf.int64, tf.string],
@ -156,9 +166,12 @@ def make_slice_tests(options):
values = [input_values, begin_values, size_values]
return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
# Note: Not all [begin x size] permutations are compatible for each grouping
# of test_parameters, but for brevity we ignore the failures rather than
# separating out each compatible set into separate test_parameters entries.
make_zip_of_tests(
options,
test_parameters,
build_graph,
build_inputs,
expected_tf_failures=27)
expected_tf_failures=29)

View File

@ -324,6 +324,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
return 1;
case BuiltinOperator_SLICE:
if (op_sig.options.single_input_op.num_dims > 4) {
return 5;
}
if (op_sig.input_types.at(0) == TensorType_INT16) {
return 4;
}
@ -796,6 +799,7 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
} break;
// TODO(b/150176627): Add tests for GetOpSignature.
case BuiltinOperator_STRIDED_SLICE:
case BuiltinOperator_SLICE:
case BuiltinOperator_SPACE_TO_BATCH_ND:
case BuiltinOperator_BATCH_TO_SPACE_ND:
case BuiltinOperator_TRANSPOSE: {

View File

@ -226,24 +226,35 @@ TEST(OpVersionTest, VersioningSliceTest) {
.op = BuiltinOperator_SLICE,
.input_types = std::vector<TensorType>{TensorType_INT16},
};
fake_op_sig.options.single_input_op.num_dims = 5;
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5);
fake_op_sig = {
.op = BuiltinOperator_SLICE,
.input_types = std::vector<TensorType>{TensorType_INT16},
};
fake_op_sig.options.single_input_op.num_dims = 4;
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
fake_op_sig = {
.op = BuiltinOperator_SLICE,
.input_types = std::vector<TensorType>{TensorType_STRING},
};
fake_op_sig.options.single_input_op.num_dims = 4;
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
fake_op_sig = {
.op = BuiltinOperator_SLICE,
.input_types = std::vector<TensorType>{TensorType_INT8},
};
fake_op_sig.options.single_input_op.num_dims = 4;
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
fake_op_sig = {
.op = BuiltinOperator_SLICE,
.input_types = std::vector<TensorType>{TensorType_UINT8},
};
fake_op_sig.options.single_input_op.num_dims = 4;
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
}

View File

@ -243,6 +243,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_SLICE, 2}, "1.14.0"},
{{BuiltinOperator_SLICE, 3}, "1.14.0"},
{{BuiltinOperator_SLICE, 4}, "2.4.0"},
{{BuiltinOperator_SLICE, 5}, kPendingReleaseVersion},
{{BuiltinOperator_TANH, 1}, "1.14.0"},
{{BuiltinOperator_TANH, 2}, "1.14.0"},
{{BuiltinOperator_TANH, 3}, "2.3.0"},