[tf.lite] Add 5D support to slice
PiperOrigin-RevId: 349439300
Change-Id: I74c91acd4d230d30f6518469949aaa94dcd9045e
parent d00879d9ae
commit 7f20831cf7
@@ -68,6 +68,7 @@
     * Both symmetric and asymmetric quantized input tensors are supported.
   * Add `RFFT2D` as builtin op. (`RFFT2D` also supports `RFFTD`.) Currently
     only supports float32 input.
+  * Add 5D support to `SLICE` op.
   * TFLite supports SignatureDef:
     * TFLiteConverter exports models with SignatureDef
     * Interpreter supports getting a list of signatures and getting callable
@@ -2085,7 +2085,7 @@ def TFL_SliceOp : TFL_Op<"slice", [
     TFL_TCresVTEtIsSameAsOp<0, 0>>,
     NoSideEffect,
     SameOperandsAndResultsScale,
-    TFL_OperandHasRankAtMost<0, 4>,
+    TFL_OperandHasRankAtMost<0, 5>,
     TFL_OperandHasRankAtMost<1, 1>,
     TFL_OperandHasRankAtMost<2, 1>]> {
   let summary = "Return a slice from 'input'.";
@@ -5402,36 +5402,32 @@ inline void Slice(const tflite::SliceParams& op_params,
                   const RuntimeShape& output_shape,
                   SequentialTensorWriter<T>* writer) {
   ruy::profiler::ScopeLabel label("Slice");
-  const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
-  // TODO(dkalenichenko): This op only supports 4D tensors or smaller.
-  TFLITE_DCHECK_LE(op_params.begin_count, 4);
-  TFLITE_DCHECK_LE(op_params.size_count, 4);
+  const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(5, input_shape);
+  TFLITE_DCHECK_LE(op_params.begin_count, 5);
+  TFLITE_DCHECK_LE(op_params.size_count, 5);
   const int begin_count = op_params.begin_count;
   const int size_count = op_params.size_count;
   // We front-pad the begin and size vectors.
-  const int start_b = 4 - begin_count > 0 ? 0 : op_params.begin[0];
-  const int stop_b = (4 - size_count > 0 || op_params.size[0] == -1)
-                         ? ext_shape.Dims(0)
-                         : start_b + op_params.size[0];
-  const int start_h = begin_count < 3 ? 0 : op_params.begin[begin_count - 3];
-  const int stop_h = (size_count < 3 || op_params.size[size_count - 3] == -1)
-                         ? ext_shape.Dims(1)
-                         : start_h + op_params.size[size_count - 3];
-  const int start_w = begin_count < 2 ? 0 : op_params.begin[begin_count - 2];
-  const int stop_w = (size_count < 2 || op_params.size[size_count - 2] == -1)
-                         ? ext_shape.Dims(2)
-                         : start_w + op_params.size[size_count - 2];
-  const int start_d = begin_count < 1 ? 0 : op_params.begin[begin_count - 1];
-  const int stop_d = (size_count < 1 || op_params.size[size_count - 1] == -1)
-                         ? ext_shape.Dims(3)
-                         : start_d + op_params.size[size_count - 1];
+  std::array<int, 5> start;
+  std::array<int, 5> stop;
+  for (int i = 0; i < 5; ++i) {
+    int padded_i = 5 - i;
+    start[i] =
+        begin_count < padded_i ? 0 : op_params.begin[begin_count - padded_i];
+    stop[i] =
+        (size_count < padded_i || op_params.size[size_count - padded_i] == -1)
+            ? ext_shape.Dims(i)
+            : start[i] + op_params.size[size_count - padded_i];
+  }

-  for (int in_b = start_b; in_b < stop_b; ++in_b) {
-    for (int in_h = start_h; in_h < stop_h; ++in_h) {
-      for (int in_w = start_w; in_w < stop_w; ++in_w) {
-        const int len = stop_d - start_d;
-        if (len > 0)
-          writer->WriteN(Offset(ext_shape, in_b, in_h, in_w, start_d), len);
+  for (int i0 = start[0]; i0 < stop[0]; ++i0) {
+    for (int i1 = start[1]; i1 < stop[1]; ++i1) {
+      for (int i2 = start[2]; i2 < stop[2]; ++i2) {
+        for (int i3 = start[3]; i3 < stop[3]; ++i3) {
+          const int len = stop[4] - start[4];
+          if (len > 0)
+            writer->WriteN(Offset(ext_shape, i0, i1, i2, i3, start[4]), len);
+        }
       }
     }
   }
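Both Slice kernels now extend the input to a 5D shape and front-pad the begin/size vectors, so dimensions the caller did not specify are taken in full. The following minimal, self-contained C++ sketch (not part of the commit; the 3D shape and slice values are made up) reproduces just that start/stop computation for a hypothetical 3D slice padded to 5D:

#include <array>
#include <cassert>
#include <vector>

int main() {
  // Hypothetical 3D input of shape [2, 4, 6], viewed as the extended 5D
  // shape [1, 1, 2, 4, 6] (front-padded with 1s, as ExtendedShape(5, ...) does).
  const std::array<int, 5> ext_dims = {1, 1, 2, 4, 6};
  // Hypothetical slice request: begin = [1, 0, 2], size = [1, -1, 3]
  // (-1 means "to the end of that dimension").
  const std::vector<int> begin = {1, 0, 2};
  const std::vector<int> size = {1, -1, 3};
  const int begin_count = static_cast<int>(begin.size());  // 3
  const int size_count = static_cast<int>(size.size());    // 3

  std::array<int, 5> start;
  std::array<int, 5> stop;
  for (int i = 0; i < 5; ++i) {
    // padded_i counts dimensions from the back: 5 for dim 0, ..., 1 for dim 4.
    const int padded_i = 5 - i;
    // Dimensions not covered by begin/size are taken whole.
    start[i] = begin_count < padded_i ? 0 : begin[begin_count - padded_i];
    stop[i] = (size_count < padded_i || size[size_count - padded_i] == -1)
                  ? ext_dims[i]
                  : start[i] + size[size_count - padded_i];
  }
  // The two padded leading dimensions stay [0, 1); the real dims follow begin/size.
  assert(start[0] == 0 && stop[0] == 1);
  assert(start[2] == 1 && stop[2] == 2);  // begin[0] = 1, size[0] = 1
  assert(start[3] == 0 && stop[3] == 4);  // size[1] = -1 -> full extent 4
  assert(start[4] == 2 && stop[4] == 5);  // begin[2] = 2, size[2] = 3
  return 0;
}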
@@ -1730,35 +1730,31 @@ inline void Slice(const tflite::SliceParams& op_params,
                   const RuntimeShape& input_shape,
                   const RuntimeShape& output_shape,
                   SequentialTensorWriter<T>* writer) {
-  const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
-  // TODO(b/174275841): This op only supports 4D tensors or smaller.
-  TFLITE_DCHECK_LE(op_params.begin_count, 4);
-  TFLITE_DCHECK_LE(op_params.size_count, 4);
+  const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(5, input_shape);
+  TFLITE_DCHECK_LE(op_params.begin_count, 5);
+  TFLITE_DCHECK_LE(op_params.size_count, 5);
   const int begin_count = op_params.begin_count;
   const int size_count = op_params.size_count;
   // We front-pad the begin and size vectors.
-  const int start_b = 4 - begin_count > 0 ? 0 : op_params.begin[0];
-  const int stop_b = (4 - size_count > 0 || op_params.size[0] == -1)
-                         ? ext_shape.Dims(0)
-                         : start_b + op_params.size[0];
-  const int start_h = begin_count < 3 ? 0 : op_params.begin[begin_count - 3];
-  const int stop_h = (size_count < 3 || op_params.size[size_count - 3] == -1)
-                         ? ext_shape.Dims(1)
-                         : start_h + op_params.size[size_count - 3];
-  const int start_w = begin_count < 2 ? 0 : op_params.begin[begin_count - 2];
-  const int stop_w = (size_count < 2 || op_params.size[size_count - 2] == -1)
-                         ? ext_shape.Dims(2)
-                         : start_w + op_params.size[size_count - 2];
-  const int start_d = begin_count < 1 ? 0 : op_params.begin[begin_count - 1];
-  const int stop_d = (size_count < 1 || op_params.size[size_count - 1] == -1)
-                         ? ext_shape.Dims(3)
-                         : start_d + op_params.size[size_count - 1];
+  std::array<int, 5> start;
+  std::array<int, 5> stop;
+  for (int i = 0; i < 5; ++i) {
+    int padded_i = 5 - i;
+    start[i] =
+        begin_count < padded_i ? 0 : op_params.begin[begin_count - padded_i];
+    stop[i] =
+        (size_count < padded_i || op_params.size[size_count - padded_i] == -1)
+            ? ext_shape.Dims(i)
+            : start[i] + op_params.size[size_count - padded_i];
+  }

-  for (int in_b = start_b; in_b < stop_b; ++in_b) {
-    for (int in_h = start_h; in_h < stop_h; ++in_h) {
-      for (int in_w = start_w; in_w < stop_w; ++in_w) {
-        for (int in_d = start_d; in_d < stop_d; ++in_d) {
-          writer->Write(Offset(ext_shape, in_b, in_h, in_w, in_d));
+  for (int i0 = start[0]; i0 < stop[0]; ++i0) {
+    for (int i1 = start[1]; i1 < stop[1]; ++i1) {
+      for (int i2 = start[2]; i2 < stop[2]; ++i2) {
+        for (int i3 = start[3]; i3 < stop[3]; ++i3) {
+          for (int i4 = start[4]; i4 < stop[4]; ++i4) {
+            writer->Write(Offset(ext_shape, i0, i1, i2, i3, i4));
+          }
         }
       }
     }
@@ -392,6 +392,20 @@ inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
   return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
 }

+inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3,
+                  int i4) {
+  TFLITE_DCHECK_EQ(shape.DimensionsCount(), 5);
+  const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo5D());
+  TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]);
+  TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]);
+  TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]);
+  TFLITE_DCHECK(i3 >= 0 && i3 < dims_data[3]);
+  TFLITE_DCHECK(i4 >= 0 && i4 < dims_data[4]);
+  return (((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3) *
+             dims_data[4] +
+         i4;
+}
+
 inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
   TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
   TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
@@ -1025,9 +1039,9 @@ struct ResizeNearestNeighborParams {

 struct SliceParams {
   int8_t begin_count;
-  int32_t begin[4];
+  int32_t begin[5];
   int8_t size_count;
-  int32_t size[4];
+  int32_t size[5];
 };

 struct SoftmaxParams {
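The new 5-argument Offset overload is plain row-major indexing: the Horner-style nesting is equivalent to multiplying each index by the product of the trailing dimensions. A small standalone check with a made-up shape (not part of the commit):

#include <cassert>

// Row-major offset for a 5D shape, written without the RuntimeShape helper.
int Offset5D(const int dims[5], int i0, int i1, int i2, int i3, int i4) {
  return (((i0 * dims[1] + i1) * dims[2] + i2) * dims[3] + i3) * dims[4] + i4;
}

int main() {
  // Hypothetical shape [2, 3, 4, 5, 6]; each stride is the product of the
  // trailing dimensions: 360, 120, 30, 6, 1.
  const int dims[5] = {2, 3, 4, 5, 6};
  // The nested form above equals the explicit stride sum for this shape.
  assert(Offset5D(dims, 1, 2, 3, 4, 5) ==
         1 * 360 + 2 * 120 + 3 * 30 + 4 * 6 + 5);  // both are 719
  return 0;
}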
@@ -207,7 +207,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_SELECT_V2, Register_SELECT_V2());
   AddBuiltin(BuiltinOperator_SLICE, Register_SLICE(),
              /* min_version = */ 1,
-             /* max_version = */ 4);
+             /* max_version = */ 5);
   AddBuiltin(BuiltinOperator_SIN, Register_SIN());
   AddBuiltin(BuiltinOperator_COS, Register_COS());
   AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV(),
@@ -366,7 +366,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
   AddBuiltin(BuiltinOperator_SELECT_V2, Register_SELECT_V2());
   AddBuiltin(BuiltinOperator_SLICE, Register_SLICE_REF(),
              /* min_version = */ 1,
-             /* max_version = */ 4);
+             /* max_version = */ 5);
   AddBuiltin(BuiltinOperator_SIN, Register_SIN());
   AddBuiltin(BuiltinOperator_COS, Register_COS());
   AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSECONV_REF(),
@@ -44,9 +44,9 @@ constexpr int kBeginTensor = 1;
 constexpr int kSizeTensor = 2;
 constexpr int kOutputTensor = 0;

-// This Op only supports 1-4D cases and since we use the optimized ops 4D
-// implementation, the 1-3D tensors are mapped to 4D.
-const int kMaxDim = 4;
+// This Op only supports 1-5D cases and since we use the optimized ops 5D
+// implementation, the 1-4D tensors are mapped to 5D.
+const int kMaxDim = 5;

 template <typename T>
 TfLiteStatus CalculateOutputShapeVector(TfLiteContext* context,
@@ -133,7 +133,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1);
   TF_LITE_ENSURE_EQ(context, NumElements(begin), NumElements(size));
   TF_LITE_ENSURE_MSG(context, NumDimensions(input) <= kMaxDim,
-                     "Slice op only supports 1D-4D input arrays.");
+                     "Slice op only supports 1D-5D input arrays.");

   // Postpone allocation of output if any of the indexing tensors is not
   // constant
@@ -184,20 +184,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     return kTfLiteError;
   }

-  // The original Slice op implementation only accepted 4-D sizes. That
-  // constraint is, for the present, maintained here.
+  // The Slice op implementation only accepts 5-D sizes. That constraint is, for
+  // the present, maintained here.
   //
   // The dimensions in the kernel used to be in reverse-order, and TFLite
   // arranged the begins and sizes vectors accordingly. This macro incorporates
   // the needed reversing.
 #define TF_LITE_SLICE(data_type, kernel_type)                   \
   {                                                             \
-    TF_LITE_ENSURE_EQ(context, begins.size(), 4);               \
-    TF_LITE_ENSURE_EQ(context, sizes.size(), 4);                \
+    TF_LITE_ENSURE_EQ(context, begins.size(), kMaxDim);         \
+    TF_LITE_ENSURE_EQ(context, sizes.size(), kMaxDim);          \
     tflite::SliceParams op_params;                              \
-    op_params.begin_count = 4;                                  \
-    op_params.size_count = 4;                                   \
-    for (int i = 0; i < 4; ++i) {                               \
+    op_params.begin_count = kMaxDim;                            \
+    op_params.size_count = kMaxDim;                             \
+    for (int i = 0; i < kMaxDim; ++i) {                         \
       op_params.begin[i] = begins[i];                           \
       op_params.size[i] = sizes[i];                             \
     }                                                           \
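With kMaxDim raised to 5, the macro expects begins and sizes to hold exactly kMaxDim entries, so lower-rank inputs must be front-padded before it runs. The sketch below only illustrates that padding convention under the assumption used by the kernels above (padded leading dimensions use begin 0 and size 1, matching the extent-1 leading dimensions of the extended shape); SliceParamsSketch and the padding loop are hypothetical stand-ins, not the kernel's actual helper code:

#include <cassert>
#include <cstdint>
#include <vector>

constexpr int kMaxDim = 5;

// Hypothetical mirror of tflite::SliceParams from types.h after this change.
struct SliceParamsSketch {
  int8_t begin_count;
  int32_t begin[5];
  int8_t size_count;
  int32_t size[5];
};

int main() {
  // Hypothetical 2D slice request: begin = [1, 2], size = [3, -1].
  std::vector<int32_t> begins = {1, 2};
  std::vector<int32_t> sizes = {3, -1};

  // Front-pad to kMaxDim entries; begin 0 / size 1 takes the whole extent of
  // each padded leading dimension, which is 1 in the extended shape.
  while (static_cast<int>(begins.size()) < kMaxDim) {
    begins.insert(begins.begin(), 0);
    sizes.insert(sizes.begin(), 1);
  }

  SliceParamsSketch op_params;
  op_params.begin_count = kMaxDim;
  op_params.size_count = kMaxDim;
  for (int i = 0; i < kMaxDim; ++i) {
    op_params.begin[i] = begins[i];
    op_params.size[i] = sizes[i];
  }
  // The original 2D request ends up in the two trailing slots.
  assert(op_params.begin[3] == 1 && op_params.begin[4] == 2);
  assert(op_params.size[3] == 3 && op_params.size[4] == -1);
  return 0;
}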
@@ -114,6 +114,16 @@ TEST_P(SliceOpTest, In3D) {
               ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
 }

+TEST_P(SliceOpTest, In5D) {
+  SliceOpModel<float, int32_t> m({5, 1, 1, 1, 1}, {5}, {1, 0, 0, 0, 0}, {5},
+                                 {3, 1, 1, 1, 1}, TensorType_INT32,
+                                 TensorType_FLOAT32, GetParam());
+  m.SetInput({1, 2, 3, 4, 5});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 1, 1, 1}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4}));
+}
+
 TEST_P(SliceOpTest, InputFloat) {
   SliceOpModel<float, int32_t> m({4, 1, 1, 1}, {4}, {1, 0, 0, 0}, {4},
                                  {3, 1, 1, 1}, TensorType_INT32,
@@ -41,6 +41,16 @@ def make_slice_tests(options):
          "constant_indices": [False],
          "fully_quantize": [False],
      },
+     # 5-D
+     {
+         "dtype": [tf.float32],
+         "index_type": [tf.int32],
+         "input_shape": [[6, 2, 2, 2, 5]],
+         "begin": [[0, 0, 0, 0, 0], [0, 1, 0, 1, 0]],
+         "size": [[4, 2, 2, 2, 3], [5, 2, 1, 1, 5]],
+         "constant_indices": [False],
+         "fully_quantize": [False],
+     },
      # 2-D
      {
          "dtype": [tf.float32, tf.int32, tf.int64, tf.string],
@@ -156,9 +166,12 @@ def make_slice_tests(options):
     values = [input_values, begin_values, size_values]
     return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))

+  # Note: Not all [begin x size] permutations are compatible for each grouping
+  # of test_parameters, but for brevity we ignore the failures rather than
+  # separating out each compatible set into separate test_parameters entries.
   make_zip_of_tests(
       options,
       test_parameters,
       build_graph,
       build_inputs,
-      expected_tf_failures=27)
+      expected_tf_failures=29)
@@ -324,6 +324,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
       return 1;

     case BuiltinOperator_SLICE:
+      if (op_sig.options.single_input_op.num_dims > 4) {
+        return 5;
+      }
       if (op_sig.input_types.at(0) == TensorType_INT16) {
         return 4;
       }
@@ -796,6 +799,7 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
     } break;
     // TODO(b/150176627): Add tests for GetOpSignature.
     case BuiltinOperator_STRIDED_SLICE:
+    case BuiltinOperator_SLICE:
     case BuiltinOperator_SPACE_TO_BATCH_ND:
     case BuiltinOperator_BATCH_TO_SPACE_ND:
     case BuiltinOperator_TRANSPOSE: {
@@ -226,24 +226,35 @@ TEST(OpVersionTest, VersioningSliceTest) {
       .op = BuiltinOperator_SLICE,
       .input_types = std::vector<TensorType>{TensorType_INT16},
   };
+  fake_op_sig.options.single_input_op.num_dims = 5;
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5);
+
+  fake_op_sig = {
+      .op = BuiltinOperator_SLICE,
+      .input_types = std::vector<TensorType>{TensorType_INT16},
+  };
+  fake_op_sig.options.single_input_op.num_dims = 4;
   EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);

   fake_op_sig = {
       .op = BuiltinOperator_SLICE,
       .input_types = std::vector<TensorType>{TensorType_STRING},
   };
+  fake_op_sig.options.single_input_op.num_dims = 4;
   EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);

   fake_op_sig = {
       .op = BuiltinOperator_SLICE,
       .input_types = std::vector<TensorType>{TensorType_INT8},
   };
+  fake_op_sig.options.single_input_op.num_dims = 4;
   EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);

   fake_op_sig = {
       .op = BuiltinOperator_SLICE,
       .input_types = std::vector<TensorType>{TensorType_UINT8},
   };
+  fake_op_sig.options.single_input_op.num_dims = 4;
   EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
 }

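Taken together with the GetBuiltinOperatorVersion change above, this test pins down the version-selection order for SLICE: 5D inputs get version 5, then int16, string, and int8 inputs get 4, 3, and 2, with everything else at 1. A condensed standalone sketch of that dispatch (the real function covers every builtin op; the enum here is a stand-in for the schema's TensorType):

#include <cassert>

enum class TensorTypeSketch { kFloat32, kUint8, kInt8, kString, kInt16 };

// Condensed version dispatch for SLICE only, mirroring the checks exercised
// by VersioningSliceTest above.
int SliceVersion(int num_dims, TensorTypeSketch input_type) {
  if (num_dims > 4) return 5;  // 5D inputs need the new kernel behavior.
  if (input_type == TensorTypeSketch::kInt16) return 4;
  if (input_type == TensorTypeSketch::kString) return 3;
  if (input_type == TensorTypeSketch::kInt8) return 2;
  return 1;
}

int main() {
  assert(SliceVersion(5, TensorTypeSketch::kInt16) == 5);
  assert(SliceVersion(4, TensorTypeSketch::kInt16) == 4);
  assert(SliceVersion(4, TensorTypeSketch::kString) == 3);
  assert(SliceVersion(4, TensorTypeSketch::kInt8) == 2);
  assert(SliceVersion(4, TensorTypeSketch::kUint8) == 1);
  return 0;
}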
@@ -243,6 +243,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
               {{BuiltinOperator_SLICE, 2}, "1.14.0"},
               {{BuiltinOperator_SLICE, 3}, "1.14.0"},
               {{BuiltinOperator_SLICE, 4}, "2.4.0"},
+              {{BuiltinOperator_SLICE, 5}, kPendingReleaseVersion},
               {{BuiltinOperator_TANH, 1}, "1.14.0"},
               {{BuiltinOperator_TANH, 2}, "1.14.0"},
               {{BuiltinOperator_TANH, 3}, "2.3.0"},