/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
|
|
|
|
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
|
|
#define TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
|
|
|
|
#include <limits>
|
|
#include <vector>
|
|
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
|
#include "tensorflow/lite/kernels/internal/types.h"
|
|
|
|
namespace tflite {
|
|
namespace strided_slice {
|
|
|
|
// Use until std::clamp() is available from C++17.
|
|
inline int Clamp(const int v, const int lo, const int hi) {
|
|
TFLITE_DCHECK(!(hi < lo));
|
|
if (hi < v) return hi;
|
|
if (v < lo) return lo;
|
|
return v;
|
|
}
|
|
|
|
// Extends the slice specification in `p` in place so that it covers
// `dim_count` dimensions. The existing entries are moved to the back of the
// start/stop/stride arrays and the newly created leading entries are set to
// start 0, stop 1, stride 1 with their begin/end mask bits set. `dim_count`
// must be at most 5 and at least the current number of indices, and the three
// index counts in `p` must agree.
inline void StridedSlicePadIndices(tflite::StridedSliceParams* p,
                                   int dim_count) {
  TFLITE_CHECK_LE(dim_count, 5);
  TFLITE_CHECK_GE(dim_count, p->start_indices_count);
  TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
  TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);

  const int original_count = p->start_indices_count;
  const int pad_count = dim_count - original_count;

  // Shift the existing entries back by pad_count. Walking from the last entry
  // towards the first guarantees no entry is overwritten before it is copied.
  for (int src = original_count - 1; src >= 0; --src) {
    const int dst = src + pad_count;
    p->strides[dst] = p->strides[src];
    p->start_indices[dst] = p->start_indices[src];
    p->stop_indices[dst] = p->stop_indices[src];
  }

  // Fill in the freshly opened leading slots.
  for (int lead = 0; lead < pad_count; ++lead) {
    p->start_indices[lead] = 0;
    p->stop_indices[lead] = 1;
    p->strides[lead] = 1;
  }

  // Move every mask up by pad_count bits so each bit still refers to the same
  // (now shifted) dimension.
  p->shrink_axis_mask <<= pad_count;
  p->ellipsis_mask <<= pad_count;
  p->new_axis_mask <<= pad_count;
  p->begin_mask <<= pad_count;
  p->end_mask <<= pad_count;

  // Set the begin/end mask bits of the padded leading dimensions.
  const int leading_bits = (1 << pad_count) - 1;
  p->begin_mask |= leading_bits;
  p->end_mask |= leading_bits;

  p->start_indices_count = dim_count;
  p->stop_indices_count = dim_count;
  p->strides_count = dim_count;
}
|
|
|
|
// Return the index for the first element along that axis. This index will be a
|
|
// positive integer between [0, axis_size - 1] that can be used to index
|
|
// directly into the data.
|
|
inline int StartForAxis(const tflite::StridedSliceParams& params,
|
|
const RuntimeShape& input_shape, int axis) {
|
|
const auto begin_mask = params.begin_mask;
|
|
const auto* start_indices = params.start_indices;
|
|
const auto* strides = params.strides;
|
|
const int axis_size = input_shape.Dims(axis);
|
|
if (axis_size == 0) {
|
|
return 0;
|
|
}
|
|
// Begin with the specified index.
|
|
int start = start_indices[axis];
|
|
|
|
// begin_mask override
|
|
if (begin_mask & 1 << axis) {
|
|
if (strides[axis] > 0) {
|
|
// Forward iteration - use the first element. These values will get
|
|
// clamped below (Note: We could have set them to 0 and axis_size-1, but
|
|
// use lowest() and max() to maintain symmetry with StopForAxis())
|
|
start = std::numeric_limits<int>::lowest();
|
|
} else {
|
|
// Backward iteration - use the last element.
|
|
start = std::numeric_limits<int>::max();
|
|
}
|
|
}
|
|
|
|
// Handle negative indices
|
|
if (start < 0) {
|
|
start += axis_size;
|
|
}
|
|
|
|
// Clamping
|
|
start = Clamp(start, 0, axis_size - 1);
|
|
|
|
return start;
|
|
}
|
|
|
|
// Return the "real" index for the end of iteration along that axis. This is an
|
|
// "end" in the traditional C sense, in that it points to one past the last
|
|
// element. ie. So if you were iterating through all elements of a 1D array of
|
|
// size 4, this function would return 4 as the stop, because it is one past the
|
|
// "real" indices of 0, 1, 2 & 3.
|
|
inline int StopForAxis(const tflite::StridedSliceParams& params,
|
|
const RuntimeShape& input_shape, int axis,
|
|
int start_for_axis) {
|
|
const auto end_mask = params.end_mask;
|
|
const auto shrink_axis_mask = params.shrink_axis_mask;
|
|
const auto* stop_indices = params.stop_indices;
|
|
const auto* strides = params.strides;
|
|
const int axis_size = input_shape.Dims(axis);
|
|
if (axis_size == 0) {
|
|
return 0;
|
|
}
|
|
|
|
// Begin with the specified index
|
|
const bool shrink_axis = shrink_axis_mask & (1 << axis);
|
|
int stop = stop_indices[axis];
|
|
|
|
// When shrinking an axis, the end position does not matter (and can be
|
|
// incorrect when negative indexing is used, see Issue #19260). Always use
|
|
// start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
|
|
// already been adjusted for negative indices.
|
|
if (shrink_axis) {
|
|
stop = start_for_axis + 1;
|
|
}
|
|
|
|
// end_mask override
|
|
if (end_mask & (1 << axis)) {
|
|
if (strides[axis] > 0) {
|
|
// Forward iteration - use the last element. These values will get
|
|
// clamped below
|
|
stop = std::numeric_limits<int>::max();
|
|
} else {
|
|
// Backward iteration - use the first element.
|
|
stop = std::numeric_limits<int>::lowest();
|
|
}
|
|
}
|
|
|
|
// Handle negative indices
|
|
if (stop < 0) {
|
|
stop += axis_size;
|
|
}
|
|
|
|
// Clamping
|
|
// Because the end index points one past the last element, we need slightly
|
|
// different clamping ranges depending on the direction.
|
|
if (strides[axis] > 0) {
|
|
// Forward iteration
|
|
stop = Clamp(stop, 0, axis_size);
|
|
} else {
|
|
// Backward iteration
|
|
stop = Clamp(stop, -1, axis_size - 1);
|
|
}
|
|
|
|
return stop;
|
|
}
|
|
|
|
// Tells whether iteration along an axis has reached its end and should loop.
// For a positive `stride` that is the case once `index` is at or past `stop`;
// for any other `stride` once `index` is at or below `stop`.
inline bool LoopCondition(int index, int stop, int stride) {
  if (stride > 0) {
    return index >= stop;
  }
  return index <= stop;
}
|
|
|
|
// Assembles a StridedSliceParams from per-dimension start/stop/stride vectors
// and the begin/end/shrink-axis masks. The number of dimensions is taken from
// `start_indices`; `stop_indices` and `strides` must have the same length.
// ellipsis_mask and new_axis_mask are always set to 0.
inline tflite::StridedSliceParams BuildStridedSliceParams(
    int begin_mask, int end_mask, int shrink_axis_mask,
    const std::vector<int>& start_indices, const std::vector<int>& stop_indices,
    const std::vector<int>& strides) {
  // All three vectors are indexed over the same range below, so a shorter
  // stop_indices or strides would be read out of bounds. Fail loudly instead.
  TFLITE_CHECK_EQ(start_indices.size(), stop_indices.size());
  TFLITE_CHECK_EQ(start_indices.size(), strides.size());

  tflite::StridedSliceParams op_params;
  // NOTE(review): dims_count is not checked against the capacity of the
  // fixed-size arrays in StridedSliceParams; StridedSlicePadIndices() limits
  // the dimension count to 5 - confirm callers never exceed that.
  const int dims_count = static_cast<int>(start_indices.size());

  op_params.start_indices_count = dims_count;
  op_params.stop_indices_count = dims_count;
  op_params.strides_count = dims_count;
  for (int i = 0; i < dims_count; ++i) {
    op_params.start_indices[i] = start_indices[i];
    op_params.stop_indices[i] = stop_indices[i];
    op_params.strides[i] = strides[i];
  }

  op_params.begin_mask = begin_mask;
  op_params.ellipsis_mask = 0;
  op_params.end_mask = end_mask;
  op_params.new_axis_mask = 0;
  op_params.shrink_axis_mask = shrink_axis_mask;

  return op_params;
}
|
|
|
|
} // namespace strided_slice
|
|
|
|
} // namespace tflite
|
|
|
|
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
|