Add 5D support to TFLite strided_slice
PiperOrigin-RevId: 297291699 Change-Id: Ib28177014451145a1bf62069899cb56e5b26756c
This commit is contained in:
parent
b05d05cba9
commit
a920083a80
@ -18,7 +18,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
namespace reference_ops {
|
||||
@ -28,47 +27,60 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
|
||||
const T* input_data,
|
||||
const RuntimeShape& unextended_output_shape,
|
||||
T* output_data) {
|
||||
using strided_slice::LoopCondition;
|
||||
using strided_slice::StartForAxis;
|
||||
using strided_slice::StopForAxis;
|
||||
// Note that the output_shape is not used herein.
|
||||
tflite::StridedSliceParams params_copy = op_params;
|
||||
|
||||
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 5);
|
||||
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 5);
|
||||
const RuntimeShape input_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_input_shape);
|
||||
RuntimeShape::ExtendedShape(5, unextended_input_shape);
|
||||
const RuntimeShape output_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_output_shape);
|
||||
RuntimeShape::ExtendedShape(5, unextended_output_shape);
|
||||
|
||||
// Reverse and pad to 4 dimensions because that is what the runtime code
|
||||
// requires (ie. all shapes must be 4D and are given backwards).
|
||||
strided_slice::StridedSlicePadIndices(¶ms_copy, 4);
|
||||
// Reverse and pad to 5 dimensions because that is what the runtime code
|
||||
// requires (ie. all shapes must be 5D and are given backwards).
|
||||
strided_slice::StridedSlicePadIndices(¶ms_copy, 5);
|
||||
|
||||
const int start_b = strided_slice::StartForAxis(params_copy, input_shape, 0);
|
||||
const int stop_b =
|
||||
strided_slice::StopForAxis(params_copy, input_shape, 0, start_b);
|
||||
const int start_h = strided_slice::StartForAxis(params_copy, input_shape, 1);
|
||||
const int stop_h =
|
||||
strided_slice::StopForAxis(params_copy, input_shape, 1, start_h);
|
||||
const int start_w = strided_slice::StartForAxis(params_copy, input_shape, 2);
|
||||
const int stop_w =
|
||||
strided_slice::StopForAxis(params_copy, input_shape, 2, start_w);
|
||||
const int start_d = strided_slice::StartForAxis(params_copy, input_shape, 3);
|
||||
const int stop_d =
|
||||
strided_slice::StopForAxis(params_copy, input_shape, 3, start_d);
|
||||
const int start_0 = StartForAxis(params_copy, input_shape, 0);
|
||||
const int stop_0 = StopForAxis(params_copy, input_shape, 0, start_0);
|
||||
const int start_1 = StartForAxis(params_copy, input_shape, 1);
|
||||
const int stop_1 = StopForAxis(params_copy, input_shape, 1, start_1);
|
||||
const int start_2 = StartForAxis(params_copy, input_shape, 2);
|
||||
const int stop_2 = StopForAxis(params_copy, input_shape, 2, start_2);
|
||||
const int start_3 = StartForAxis(params_copy, input_shape, 3);
|
||||
const int stop_3 = StopForAxis(params_copy, input_shape, 3, start_3);
|
||||
const int start_4 = StartForAxis(params_copy, input_shape, 4);
|
||||
const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4);
|
||||
|
||||
T* out_ptr = output_data;
|
||||
for (int in_b = start_b;
|
||||
!strided_slice::LoopCondition(in_b, stop_b, params_copy.strides[0]);
|
||||
in_b += params_copy.strides[0]) {
|
||||
for (int in_h = start_h;
|
||||
!strided_slice::LoopCondition(in_h, stop_h, params_copy.strides[1]);
|
||||
in_h += params_copy.strides[1]) {
|
||||
for (int in_w = start_w;
|
||||
!strided_slice::LoopCondition(in_w, stop_w, params_copy.strides[2]);
|
||||
in_w += params_copy.strides[2]) {
|
||||
for (int in_d = start_d; !strided_slice::LoopCondition(
|
||||
in_d, stop_d, params_copy.strides[3]);
|
||||
in_d += params_copy.strides[3]) {
|
||||
*out_ptr++ = input_data[Offset(input_shape, in_b, in_h, in_w, in_d)];
|
||||
for (int offset_0 = start_0 * input_shape.Dims(1),
|
||||
end_0 = stop_0 * input_shape.Dims(1),
|
||||
step_0 = params_copy.strides[0] * input_shape.Dims(1);
|
||||
!LoopCondition(offset_0, end_0, params_copy.strides[0]);
|
||||
offset_0 += step_0) {
|
||||
for (int offset_1 = (offset_0 + start_1) * input_shape.Dims(2),
|
||||
end_1 = (offset_0 + stop_1) * input_shape.Dims(2),
|
||||
step_1 = params_copy.strides[1] * input_shape.Dims(2);
|
||||
!LoopCondition(offset_1, end_1, params_copy.strides[1]);
|
||||
offset_1 += step_1) {
|
||||
for (int offset_2 = (offset_1 + start_2) * input_shape.Dims(3),
|
||||
end_2 = (offset_1 + stop_2) * input_shape.Dims(3),
|
||||
step_2 = params_copy.strides[2] * input_shape.Dims(3);
|
||||
!LoopCondition(offset_2, end_2, params_copy.strides[2]);
|
||||
offset_2 += step_2) {
|
||||
for (int offset_3 = (offset_2 + start_3) * input_shape.Dims(4),
|
||||
end_3 = (offset_2 + stop_3) * input_shape.Dims(4),
|
||||
step_3 = params_copy.strides[3] * input_shape.Dims(4);
|
||||
!LoopCondition(offset_3, end_3, params_copy.strides[3]);
|
||||
offset_3 += step_3) {
|
||||
for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4;
|
||||
!LoopCondition(offset_4, end_4, params_copy.strides[4]);
|
||||
offset_4 += params_copy.strides[4]) {
|
||||
*out_ptr++ = input_data[offset_4];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -35,7 +35,7 @@ inline int Clamp(const int v, const int lo, const int hi) {
|
||||
inline void StridedSlicePadIndices(tflite::StridedSliceParams* p,
|
||||
int dim_count) {
|
||||
// Add indices and mask bits to fully include extra dimensions
|
||||
TFLITE_CHECK_LE(dim_count, 4);
|
||||
TFLITE_CHECK_LE(dim_count, 5);
|
||||
TFLITE_CHECK_GE(dim_count, p->start_indices_count);
|
||||
TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
|
||||
TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
|
||||
|
@ -128,9 +128,9 @@ struct Dims {
|
||||
|
||||
class RuntimeShape {
|
||||
public:
|
||||
// Shapes with dimensions up to 4 are stored directly in the structure, while
|
||||
// Shapes with dimensions up to 5 are stored directly in the structure, while
|
||||
// larger shapes are separately allocated.
|
||||
static constexpr int kMaxSmallSize = 4;
|
||||
static constexpr int kMaxSmallSize = 5;
|
||||
|
||||
RuntimeShape& operator=(RuntimeShape const&) = delete;
|
||||
|
||||
@ -207,8 +207,8 @@ class RuntimeShape {
|
||||
inline const int32* DimsData() const {
|
||||
return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
|
||||
}
|
||||
// The caller must ensure that the shape is no bigger than 4-D.
|
||||
inline const int32* DimsDataUpTo4D() const { return dims_; }
|
||||
// The caller must ensure that the shape is no bigger than 5-D.
|
||||
inline const int32* DimsDataUpTo5D() const { return dims_; }
|
||||
|
||||
inline void Resize(int dimensions_count) {
|
||||
if (size_ > kMaxSmallSize) {
|
||||
@ -378,7 +378,7 @@ inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
|
||||
|
||||
inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
|
||||
TFLITE_DCHECK_EQ(shape.DimensionsCount(), 4);
|
||||
const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo4D());
|
||||
const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo5D());
|
||||
TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]);
|
||||
TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]);
|
||||
TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]);
|
||||
@ -1049,11 +1049,11 @@ struct SqueezeParams {
|
||||
|
||||
struct StridedSliceParams {
|
||||
int8 start_indices_count;
|
||||
int32 start_indices[4];
|
||||
int32 start_indices[5];
|
||||
int8 stop_indices_count;
|
||||
int32 stop_indices[4];
|
||||
int32 stop_indices[5];
|
||||
int8 strides_count;
|
||||
int32 strides[4];
|
||||
int32 strides[5];
|
||||
|
||||
int16 begin_mask;
|
||||
int16 ellipsis_mask;
|
||||
|
@ -156,7 +156,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
|
||||
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE(),
|
||||
/* min_version */ 1,
|
||||
/* max_version */ 3);
|
||||
/* max_version */ 4);
|
||||
AddBuiltin(BuiltinOperator_EXP, Register_EXP());
|
||||
AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2(),
|
||||
/* min_version */ 1,
|
||||
|
@ -142,8 +142,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32);
|
||||
TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32);
|
||||
TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32);
|
||||
TF_LITE_ENSURE_MSG(context, op_context.dims <= 4,
|
||||
"StridedSlice op only supports 1D-4D input arrays.");
|
||||
TF_LITE_ENSURE_MSG(context, op_context.dims <= 5,
|
||||
"StridedSlice op only supports 1D-5D input arrays.");
|
||||
|
||||
// TODO(b/138098220): Remove when bug is resolved.
|
||||
// Currently, working on using the compiler to cannonize strided_slice,
|
||||
|
@ -82,9 +82,9 @@ TYPED_TEST_SUITE(StridedSliceOpTest, DataTypes);
|
||||
|
||||
#ifdef GTEST_HAS_DEATH_TEST
|
||||
TYPED_TEST(StridedSliceOpTest, UnsupportedInputSize) {
|
||||
EXPECT_DEATH(StridedSliceOpModel<TypeParam>({2, 2, 2, 2, 2}, {5}, {5}, {5}, 0,
|
||||
0, 0, 0, 0),
|
||||
"StridedSlice op only supports 1D-4D input arrays.");
|
||||
EXPECT_DEATH(StridedSliceOpModel<TypeParam>({2, 2, 2, 2, 2, 2}, {5}, {5}, {5},
|
||||
0, 0, 0, 0, 0),
|
||||
"StridedSlice op only supports 1D-5D input arrays.");
|
||||
}
|
||||
|
||||
TYPED_TEST(StridedSliceOpTest, UnssupportedArgs) {
|
||||
@ -612,5 +612,29 @@ TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1int8) {
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||
}
|
||||
|
||||
TYPED_TEST(StridedSliceOpTest, In5D_Identity) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 2, 2, 1, 2}, {5}, {5}, {5}, 0, 0, 0, 0,
|
||||
0);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
m.SetBegin({0, 0, 0, 0, 0});
|
||||
m.SetEnd({2, 1, 2, 1, 2});
|
||||
m.SetStrides({1, 1, 1, 1, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2, 1, 2}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 9, 10, 11, 12}));
|
||||
}
|
||||
|
||||
TYPED_TEST(StridedSliceOpTest, In5D_IdentityShrinkAxis1) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 2, 2, 1, 2}, {5}, {5}, {5}, 0, 0, 0, 0,
|
||||
1);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
m.SetBegin({0, 0, 0, 0, 0});
|
||||
m.SetEnd({2, 1, 2, 1, 2});
|
||||
m.SetStrides({1, 1, 1, 1, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 1, 2}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4}));
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
@ -68,6 +68,18 @@ def make_strided_slice_np_style_tests(options):
|
||||
[slice(1, 11, 3), Ellipsis,
|
||||
slice(3, 7, 2)]],
|
||||
},
|
||||
# Ellipsis 5d.
|
||||
{
|
||||
"dtype": [tf.float32],
|
||||
"shape": [[11, 21, 15, 7, 9]],
|
||||
"spec": [[
|
||||
slice(3, 7, 2),
|
||||
slice(None),
|
||||
slice(None),
|
||||
slice(None),
|
||||
slice(None)
|
||||
], [Ellipsis, slice(3, 7, 2)]],
|
||||
},
|
||||
# All combinations.
|
||||
{
|
||||
"dtype": [tf.float32],
|
||||
|
@ -1165,6 +1165,15 @@ class StridedSlice
|
||||
op->new_axis_mask = options.new_axis_mask();
|
||||
op->shrink_axis_mask = options.shrink_axis_mask();
|
||||
}
|
||||
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
const auto& ss_op =
|
||||
static_cast<const StridedSliceOperator&>(*op_signature.op);
|
||||
::tflite::OpSignature op_sig =
|
||||
GetVersioningOpSig(builtin_op(), op_signature);
|
||||
op_sig.options.strided_slice.num_dims = ss_op.start_indices.size();
|
||||
return ::tflite::GetBuiltinOperatorVersion(op_sig);
|
||||
}
|
||||
};
|
||||
|
||||
class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
|
||||
|
@ -266,6 +266,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
||||
}
|
||||
return 1;
|
||||
case BuiltinOperator_STRIDED_SLICE:
|
||||
if (op_sig.options.strided_slice.num_dims > 4) {
|
||||
return 4;
|
||||
}
|
||||
// If the op takes bool input, it is version 3.
|
||||
if (op_sig.input_types.at(0) == TensorType_BOOL) {
|
||||
return 3;
|
||||
@ -431,6 +434,11 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
|
||||
resize_bilinear_option->half_pixel_centers();
|
||||
}
|
||||
} break;
|
||||
// TODO(b/150176627): Add tests for GetOpSignature.
|
||||
case BuiltinOperator_STRIDED_SLICE: {
|
||||
op_sig.options.strided_slice.num_dims =
|
||||
subgraph->tensors()->Get(op->inputs()->Get(0))->shape()->size();
|
||||
} break;
|
||||
|
||||
default:
|
||||
break;
|
||||
|
@ -49,6 +49,9 @@ typedef struct {
|
||||
struct {
|
||||
bool half_pixel_centers;
|
||||
} resize_bilinear;
|
||||
struct {
|
||||
int32_t num_dims;
|
||||
} strided_slice;
|
||||
} options;
|
||||
} OpSignature;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user