Add 5D support to TFLite strided_slice

PiperOrigin-RevId: 297291699
Change-Id: Ib28177014451145a1bf62069899cb56e5b26756c
This commit is contained in:
Thai Nguyen 2020-02-25 23:29:26 -08:00 committed by TensorFlower Gardener
parent b05d05cba9
commit a920083a80
10 changed files with 116 additions and 48 deletions

View File

@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace reference_ops {
@ -28,47 +27,60 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
const T* input_data,
const RuntimeShape& unextended_output_shape,
T* output_data) {
using strided_slice::LoopCondition;
using strided_slice::StartForAxis;
using strided_slice::StopForAxis;
// Note that the output_shape is not used herein.
tflite::StridedSliceParams params_copy = op_params;
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 5);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 5);
const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
RuntimeShape::ExtendedShape(5, unextended_input_shape);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
RuntimeShape::ExtendedShape(5, unextended_output_shape);
// Reverse and pad to 4 dimensions because that is what the runtime code
// requires (ie. all shapes must be 4D and are given backwards).
strided_slice::StridedSlicePadIndices(&params_copy, 4);
// Reverse and pad to 5 dimensions because that is what the runtime code
// requires (ie. all shapes must be 5D and are given backwards).
strided_slice::StridedSlicePadIndices(&params_copy, 5);
const int start_b = strided_slice::StartForAxis(params_copy, input_shape, 0);
const int stop_b =
strided_slice::StopForAxis(params_copy, input_shape, 0, start_b);
const int start_h = strided_slice::StartForAxis(params_copy, input_shape, 1);
const int stop_h =
strided_slice::StopForAxis(params_copy, input_shape, 1, start_h);
const int start_w = strided_slice::StartForAxis(params_copy, input_shape, 2);
const int stop_w =
strided_slice::StopForAxis(params_copy, input_shape, 2, start_w);
const int start_d = strided_slice::StartForAxis(params_copy, input_shape, 3);
const int stop_d =
strided_slice::StopForAxis(params_copy, input_shape, 3, start_d);
const int start_0 = StartForAxis(params_copy, input_shape, 0);
const int stop_0 = StopForAxis(params_copy, input_shape, 0, start_0);
const int start_1 = StartForAxis(params_copy, input_shape, 1);
const int stop_1 = StopForAxis(params_copy, input_shape, 1, start_1);
const int start_2 = StartForAxis(params_copy, input_shape, 2);
const int stop_2 = StopForAxis(params_copy, input_shape, 2, start_2);
const int start_3 = StartForAxis(params_copy, input_shape, 3);
const int stop_3 = StopForAxis(params_copy, input_shape, 3, start_3);
const int start_4 = StartForAxis(params_copy, input_shape, 4);
const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4);
T* out_ptr = output_data;
for (int in_b = start_b;
!strided_slice::LoopCondition(in_b, stop_b, params_copy.strides[0]);
in_b += params_copy.strides[0]) {
for (int in_h = start_h;
!strided_slice::LoopCondition(in_h, stop_h, params_copy.strides[1]);
in_h += params_copy.strides[1]) {
for (int in_w = start_w;
!strided_slice::LoopCondition(in_w, stop_w, params_copy.strides[2]);
in_w += params_copy.strides[2]) {
for (int in_d = start_d; !strided_slice::LoopCondition(
in_d, stop_d, params_copy.strides[3]);
in_d += params_copy.strides[3]) {
*out_ptr++ = input_data[Offset(input_shape, in_b, in_h, in_w, in_d)];
for (int offset_0 = start_0 * input_shape.Dims(1),
end_0 = stop_0 * input_shape.Dims(1),
step_0 = params_copy.strides[0] * input_shape.Dims(1);
!LoopCondition(offset_0, end_0, params_copy.strides[0]);
offset_0 += step_0) {
for (int offset_1 = (offset_0 + start_1) * input_shape.Dims(2),
end_1 = (offset_0 + stop_1) * input_shape.Dims(2),
step_1 = params_copy.strides[1] * input_shape.Dims(2);
!LoopCondition(offset_1, end_1, params_copy.strides[1]);
offset_1 += step_1) {
for (int offset_2 = (offset_1 + start_2) * input_shape.Dims(3),
end_2 = (offset_1 + stop_2) * input_shape.Dims(3),
step_2 = params_copy.strides[2] * input_shape.Dims(3);
!LoopCondition(offset_2, end_2, params_copy.strides[2]);
offset_2 += step_2) {
for (int offset_3 = (offset_2 + start_3) * input_shape.Dims(4),
end_3 = (offset_2 + stop_3) * input_shape.Dims(4),
step_3 = params_copy.strides[3] * input_shape.Dims(4);
!LoopCondition(offset_3, end_3, params_copy.strides[3]);
offset_3 += step_3) {
for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4;
!LoopCondition(offset_4, end_4, params_copy.strides[4]);
offset_4 += params_copy.strides[4]) {
*out_ptr++ = input_data[offset_4];
}
}
}
}

View File

@ -35,7 +35,7 @@ inline int Clamp(const int v, const int lo, const int hi) {
inline void StridedSlicePadIndices(tflite::StridedSliceParams* p,
int dim_count) {
// Add indices and mask bits to fully include extra dimensions
TFLITE_CHECK_LE(dim_count, 4);
TFLITE_CHECK_LE(dim_count, 5);
TFLITE_CHECK_GE(dim_count, p->start_indices_count);
TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);

View File

@ -128,9 +128,9 @@ struct Dims {
class RuntimeShape {
public:
// Shapes with dimensions up to 4 are stored directly in the structure, while
// Shapes with dimensions up to 5 are stored directly in the structure, while
// larger shapes are separately allocated.
static constexpr int kMaxSmallSize = 4;
static constexpr int kMaxSmallSize = 5;
RuntimeShape& operator=(RuntimeShape const&) = delete;
@ -207,8 +207,8 @@ class RuntimeShape {
inline const int32* DimsData() const {
return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
}
// The caller must ensure that the shape is no bigger than 4-D.
inline const int32* DimsDataUpTo4D() const { return dims_; }
// The caller must ensure that the shape is no bigger than 5-D.
inline const int32* DimsDataUpTo5D() const { return dims_; }
inline void Resize(int dimensions_count) {
if (size_ > kMaxSmallSize) {
@ -378,7 +378,7 @@ inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
TFLITE_DCHECK_EQ(shape.DimensionsCount(), 4);
const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo4D());
const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo5D());
TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]);
TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]);
TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]);
@ -1049,11 +1049,11 @@ struct SqueezeParams {
struct StridedSliceParams {
int8 start_indices_count;
int32 start_indices[4];
int32 start_indices[5];
int8 stop_indices_count;
int32 stop_indices[4];
int32 stop_indices[5];
int8 strides_count;
int32 strides[4];
int32 strides[5];
int16 begin_mask;
int16 ellipsis_mask;

View File

@ -156,7 +156,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE(),
/* min_version */ 1,
/* max_version */ 3);
/* max_version */ 4);
AddBuiltin(BuiltinOperator_EXP, Register_EXP());
AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2(),
/* min_version */ 1,

View File

@ -142,8 +142,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32);
TF_LITE_ENSURE_MSG(context, op_context.dims <= 4,
"StridedSlice op only supports 1D-4D input arrays.");
TF_LITE_ENSURE_MSG(context, op_context.dims <= 5,
"StridedSlice op only supports 1D-5D input arrays.");
// TODO(b/138098220): Remove when bug is resolved.
// Currently, working on using the compiler to cannonize strided_slice,

View File

@ -82,9 +82,9 @@ TYPED_TEST_SUITE(StridedSliceOpTest, DataTypes);
#ifdef GTEST_HAS_DEATH_TEST
TYPED_TEST(StridedSliceOpTest, UnsupportedInputSize) {
EXPECT_DEATH(StridedSliceOpModel<TypeParam>({2, 2, 2, 2, 2}, {5}, {5}, {5}, 0,
0, 0, 0, 0),
"StridedSlice op only supports 1D-4D input arrays.");
EXPECT_DEATH(StridedSliceOpModel<TypeParam>({2, 2, 2, 2, 2, 2}, {5}, {5}, {5},
0, 0, 0, 0, 0),
"StridedSlice op only supports 1D-5D input arrays.");
}
TYPED_TEST(StridedSliceOpTest, UnssupportedArgs) {
@ -612,5 +612,29 @@ TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1int8) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
}
TYPED_TEST(StridedSliceOpTest, In5D_Identity) {
StridedSliceOpModel<TypeParam> m({2, 2, 2, 1, 2}, {5}, {5}, {5}, 0, 0, 0, 0,
0);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
m.SetBegin({0, 0, 0, 0, 0});
m.SetEnd({2, 1, 2, 1, 2});
m.SetStrides({1, 1, 1, 1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2, 1, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 9, 10, 11, 12}));
}
TYPED_TEST(StridedSliceOpTest, In5D_IdentityShrinkAxis1) {
StridedSliceOpModel<TypeParam> m({2, 2, 2, 1, 2}, {5}, {5}, {5}, 0, 0, 0, 0,
1);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
m.SetBegin({0, 0, 0, 0, 0});
m.SetEnd({2, 1, 2, 1, 2});
m.SetStrides({1, 1, 1, 1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 1, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4}));
}
} // namespace
} // namespace tflite

View File

@ -68,6 +68,18 @@ def make_strided_slice_np_style_tests(options):
[slice(1, 11, 3), Ellipsis,
slice(3, 7, 2)]],
},
# Ellipsis 5d.
{
"dtype": [tf.float32],
"shape": [[11, 21, 15, 7, 9]],
"spec": [[
slice(3, 7, 2),
slice(None),
slice(None),
slice(None),
slice(None)
], [Ellipsis, slice(3, 7, 2)]],
},
# All combinations.
{
"dtype": [tf.float32],

View File

@ -1165,6 +1165,15 @@ class StridedSlice
op->new_axis_mask = options.new_axis_mask();
op->shrink_axis_mask = options.shrink_axis_mask();
}
int GetVersion(const OperatorSignature& op_signature) const override {
const auto& ss_op =
static_cast<const StridedSliceOperator&>(*op_signature.op);
::tflite::OpSignature op_sig =
GetVersioningOpSig(builtin_op(), op_signature);
op_sig.options.strided_slice.num_dims = ss_op.start_indices.size();
return ::tflite::GetBuiltinOperatorVersion(op_sig);
}
};
class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,

View File

@ -266,6 +266,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
}
return 1;
case BuiltinOperator_STRIDED_SLICE:
if (op_sig.options.strided_slice.num_dims > 4) {
return 4;
}
// If the op takes bool input, it is version 3.
if (op_sig.input_types.at(0) == TensorType_BOOL) {
return 3;
@ -431,6 +434,11 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
resize_bilinear_option->half_pixel_centers();
}
} break;
// TODO(b/150176627): Add tests for GetOpSignature.
case BuiltinOperator_STRIDED_SLICE: {
op_sig.options.strided_slice.num_dims =
subgraph->tensors()->Get(op->inputs()->Get(0))->shape()->size();
} break;
default:
break;

View File

@ -49,6 +49,9 @@ typedef struct {
struct {
bool half_pixel_centers;
} resize_bilinear;
struct {
int32_t num_dims;
} strided_slice;
} options;
} OpSignature;