[tf.lite] Optimize strided_slice when inner stride == 1
PiperOrigin-RevId: 350407259
Change-Id: I8265097ad85aab0b0ff998f9dd465eb2ecf20183
Parent: 7d841e13c4
Commit: 335857da1b
Affected files (under tensorflow/lite/kernels, inferred from the header guards and namespaces in the hunks below):
- internal/optimized/legacy_optimized_ops.h
- internal/optimized/optimized_ops.h
- internal/reference/strided_slice.h
- slice.cc
- strided_slice.cc
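The change adds a fast path for StridedSlice when the innermost stride is 1: in that case the innermost loop visits a contiguous run of elements, so per-element copies can be replaced by one bulk write per run. A minimal standalone sketch of the idea (illustrative names, not the TFLite API), shown here for a 2-D row-major buffer:

#include <cstring>

// Slice rows [row_start, row_stop) with a row stride, and columns
// [col_start, col_stop) with an (inner) column stride of 1, from a
// row-major [rows, cols] buffer. Because the inner stride is 1, each
// row's selected columns form one contiguous run, copyable with memcpy.
void SliceRowsInnerStrideOne(const float* input, int cols, int row_start,
                             int row_stop, int row_stride, int col_start,
                             int col_stop, float* output) {
  const int run = col_stop - col_start;  // length of each contiguous run
  for (int r = row_start; r < row_stop; r += row_stride) {
    std::memcpy(output, input + r * cols + col_start,
                run * sizeof(float));  // one bulk copy per row
    output += run;
  }
}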
tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h

@@ -71,7 +71,6 @@ using reference_ops::ReluX;
 using reference_ops::Select;
 using reference_ops::SpaceToBatchND;
 using reference_ops::Split;
-using reference_ops::StridedSlice;
 using reference_ops::TensorFlowSplit;
 
 static constexpr int kDepthwiseReverseShift = -1;
@@ -4949,6 +4948,23 @@ void Transpose(const T* input, const Dims<4>& input_dims, T* output,
                output);
 }
 
+template <typename T>
+inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
+                         int begin_mask, int end_mask, int shrink_axis_mask,
+                         const std::vector<int>& start_indices,
+                         const std::vector<int>& stop_indices,
+                         const std::vector<int>& strides, T* output_data,
+                         const Dims<4>& output_dims) {
+  TFLITE_DCHECK_EQ(start_indices.size(), 4);
+  auto op_params = strided_slice::BuildStridedSliceParams(
+      begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
+      strides);
+  reference_ops::StridedSliceReverseIndices(&op_params);
+
+  StridedSlice(op_params, DimsToShape(input_dims), input_data,
+               DimsToShape(output_dims), output_data);
+}
+
 }  // namespace optimized_ops
 }  // namespace tflite
 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
tensorflow/lite/kernels/internal/optimized/optimized_ops.h

@@ -104,7 +104,6 @@ using reference_ops::Round;
 using reference_ops::Select;
 using reference_ops::SpaceToBatchND;
 using reference_ops::Split;
-using reference_ops::StridedSlice;
 using reference_ops::Sub16;
 
 // TODO(b/80247582) Remove this constant.
@@ -5449,6 +5448,106 @@ inline void Slice(const tflite::SliceParams& op_params,
   return Slice(op_params, input_shape, output_shape, &writer);
 }
 
+template <typename T>
+inline void StridedSlice(const tflite::StridedSliceParams& op_params,
+                         const RuntimeShape& unextended_input_shape,
+                         const RuntimeShape& unextended_output_shape,
+                         SequentialTensorWriter<T>* writer) {
+  using strided_slice::LoopCondition;
+  using strided_slice::StartForAxis;
+  using strided_slice::StopForAxis;
+
+  // We only have an optimized implementation for the case where the inner-most
+  // stride == 1. For all other cases, fall back to the reference impl.
+  if ((op_params.strides_count <= 0) ||
+      (op_params.strides[op_params.strides_count - 1] != 1)) {
+    reference_ops::StridedSlice(op_params, unextended_input_shape,
+                                unextended_output_shape, writer);
+    return;
+  }
+
+  ruy::profiler::ScopeLabel label("StridedSliceInnerStrideOne");
+
+  // Note that the output_shape is not used herein.
+  tflite::StridedSliceParams params_copy = op_params;
+
+  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 5);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 5);
+  const RuntimeShape input_shape =
+      RuntimeShape::ExtendedShape(5, unextended_input_shape);
+  const RuntimeShape output_shape =
+      RuntimeShape::ExtendedShape(5, unextended_output_shape);
+
+  // Reverse and pad to 5 dimensions because that is what the runtime code
+  // requires (i.e. all shapes must be 5D and are given backwards).
+  strided_slice::StridedSlicePadIndices(&params_copy, 5);
+  TFLITE_DCHECK_EQ(params_copy.strides[4], 1);
+
+  const int start_0 = StartForAxis(params_copy, input_shape, 0);
+  const int stop_0 = StopForAxis(params_copy, input_shape, 0, start_0);
+  const int start_1 = StartForAxis(params_copy, input_shape, 1);
+  const int stop_1 = StopForAxis(params_copy, input_shape, 1, start_1);
+  const int start_2 = StartForAxis(params_copy, input_shape, 2);
+  const int stop_2 = StopForAxis(params_copy, input_shape, 2, start_2);
+  const int start_3 = StartForAxis(params_copy, input_shape, 3);
+  const int stop_3 = StopForAxis(params_copy, input_shape, 3, start_3);
+  const int start_4 = StartForAxis(params_copy, input_shape, 4);
+  const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4);
+
+  for (int offset_0 = start_0 * input_shape.Dims(1),
+           end_0 = stop_0 * input_shape.Dims(1),
+           step_0 = params_copy.strides[0] * input_shape.Dims(1);
+       !LoopCondition(offset_0, end_0, params_copy.strides[0]);
+       offset_0 += step_0) {
+    for (int offset_1 = (offset_0 + start_1) * input_shape.Dims(2),
+             end_1 = (offset_0 + stop_1) * input_shape.Dims(2),
+             step_1 = params_copy.strides[1] * input_shape.Dims(2);
+         !LoopCondition(offset_1, end_1, params_copy.strides[1]);
+         offset_1 += step_1) {
+      for (int offset_2 = (offset_1 + start_2) * input_shape.Dims(3),
+               end_2 = (offset_1 + stop_2) * input_shape.Dims(3),
+               step_2 = params_copy.strides[2] * input_shape.Dims(3);
+           !LoopCondition(offset_2, end_2, params_copy.strides[2]);
+           offset_2 += step_2) {
+        for (int offset_3 = (offset_2 + start_3) * input_shape.Dims(4),
+                 end_3 = (offset_2 + stop_3) * input_shape.Dims(4),
+                 step_3 = params_copy.strides[3] * input_shape.Dims(4);
+             !LoopCondition(offset_3, end_3, params_copy.strides[3]);
+             offset_3 += step_3) {
+          // Note: We've already validated that the inner-most stride is 1, so
+          // we can safely write the full inner sequence.
+          const int len = stop_4 - start_4;
+          if (len > 0) {
+            writer->WriteN(offset_3 + start_4, len);
+          }
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+inline void StridedSlice(const tflite::StridedSliceParams& op_params,
+                         const RuntimeShape& unextended_input_shape,
+                         const T* input_data,
+                         const RuntimeShape& unextended_output_shape,
+                         T* output_data) {
+  SequentialTensorWriter<T> writer(input_data, output_data);
+  StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
+                  &writer);
+}
+
+template <typename T>
+inline void StridedSlice(const tflite::StridedSliceParams& op_params,
+                         const RuntimeShape& unextended_input_shape,
+                         const TfLiteTensor* input,
+                         const RuntimeShape& unextended_output_shape,
+                         TfLiteTensor* output) {
+  SequentialTensorWriter<T> writer(input, output);
+  StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
+                  &writer);
+}
+
 template <typename T>
 void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
              const T* input2_data, const RuntimeShape& output_shape,
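A note on the loop nest added above: each level folds its start index into a running flat offset before scaling by the next dimension, so by the innermost level `offset_3 + start_4` is already the flat index of the first element of the contiguous run, and the `stop_4 - start_4` elements are emitted with a single `WriteN` call. A small standalone check of that offset algebra (illustrative, not TFLite code):

#include <cassert>

int main() {
  // Row-major flat index of (i0, i1, i2) in a [D0, D1, D2] array is
  // (i0 * D1 + i1) * D2 + i2. The loop nest builds the same value
  // incrementally, folding each index in before scaling by the next dim.
  const int D1 = 4, D2 = 5;
  const int i0 = 2, i1 = 3, i2 = 1;
  const int offset_0 = i0 * D1;               // outer loop variable
  const int offset_1 = (offset_0 + i1) * D2;  // next level folds in i1
  const int flat = offset_1 + i2;             // innermost adds i2 directly
  assert(flat == (i0 * D1 + i1) * D2 + i2);
  return 0;
}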
tensorflow/lite/kernels/internal/reference/strided_slice.h

@@ -15,6 +15,7 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_STRIDED_SLICE_H_
 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_STRIDED_SLICE_H_
 
+#include "ruy/profiler/instrumentation.h"  // from @ruy
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
 #include "tensorflow/lite/kernels/internal/portable_tensor.h"
@@ -33,6 +34,9 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
   using strided_slice::LoopCondition;
   using strided_slice::StartForAxis;
   using strided_slice::StopForAxis;
+
+  ruy::profiler::ScopeLabel label("StridedSlice");
+
   // Note that the output_shape is not used herein.
   tflite::StridedSliceParams params_copy = op_params;
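The two hunks above thread ruy's profiler instrumentation into the reference kernel, mirroring the "StridedSliceInnerStrideOne" label in the optimized path. A hedged sketch of how these labels behave (based on ruy's instrumentation header; the function name below is made up):

#include "ruy/profiler/instrumentation.h"

void SomeKernel() {
  // RAII label: while this object is alive, profiler samples taken inside
  // the scope are attributed to "SomeKernel". In builds that do not define
  // RUY_PROFILER, the label compiles away to a no-op.
  ruy::profiler::ScopeLabel label("SomeKernel");
  // ... kernel body; time spent here is credited to "SomeKernel" ...
}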
tensorflow/lite/kernels/slice.cc

@@ -190,7 +190,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   // The dimensions in the kernel used to be in reverse-order, and TFLite
   // arranged the begins and sizes vectors accordingly. This macro incorporates
   // the needed reversing.
-#define TF_LITE_SLICE(data_type, kernel_type) \
+#define TF_LITE_SLICE(data_type) \
   { \
     TF_LITE_ENSURE_EQ(context, begins.size(), kMaxDim); \
     TF_LITE_ENSURE_EQ(context, sizes.size(), kMaxDim); \
@@ -213,28 +213,28 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 
   switch (input->type) {
     case kTfLiteFloat32:
-      TF_LITE_SLICE(float, kernel_type);
+      TF_LITE_SLICE(float);
       break;
     case kTfLiteInt32:
-      TF_LITE_SLICE(int32_t, kernel_type);
+      TF_LITE_SLICE(int32_t);
       break;
     case kTfLiteInt64:
-      TF_LITE_SLICE(int64_t, kernel_type);
+      TF_LITE_SLICE(int64_t);
       break;
    case kTfLiteInt8:
-      TF_LITE_SLICE(int8_t, kernel_type);
+      TF_LITE_SLICE(int8_t);
       break;
     case kTfLiteInt16:
-      TF_LITE_SLICE(int16_t, kernel_type);
+      TF_LITE_SLICE(int16_t);
       break;
     case kTfLiteUInt8:
-      TF_LITE_SLICE(uint8_t, kernel_type);
+      TF_LITE_SLICE(uint8_t);
       break;
     case kTfLiteBool:
-      TF_LITE_SLICE(bool, kernel_type);
+      TF_LITE_SLICE(bool);
       break;
     case kTfLiteString:
-      TF_LITE_SLICE(string, kernel_type);
+      TF_LITE_SLICE(string);
       break;
     default:
       context->ReportError(
tensorflow/lite/kernels/strided_slice.cc

@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
@@ -37,7 +38,7 @@ namespace strided_slice {
 
 enum KernelType {
   kReference,
-  // TODO(b/175642009): add kGenericOptimized
+  kGenericOptimized,
 };
 
 constexpr int kInputTensor = 0;
@@ -190,51 +191,43 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   }
   StridedSliceParams op_params = BuildStridedSliceParams(&op_context);
 
-#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
-  kernel_type::StridedSlice<data_type>( \
-      op_params, GetTensorShape(op_context.input), op_context.input, \
-      GetTensorShape(op_context.output), op_context.output)
+#define TF_LITE_STRIDED_SLICE(data_type) \
+  { \
+    if (kernel_type == kGenericOptimized) { \
+      optimized_ops::StridedSlice<data_type>( \
+          op_params, GetTensorShape(op_context.input), op_context.input, \
+          GetTensorShape(op_context.output), op_context.output); \
+    } else { \
+      reference_ops::StridedSlice<data_type>( \
+          op_params, GetTensorShape(op_context.input), op_context.input, \
+          GetTensorShape(op_context.output), op_context.output); \
+    } \
+  }
 
   switch (op_context.input->type) {
     case kTfLiteFloat32:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, float);
-      }
+      TF_LITE_STRIDED_SLICE(float);
       break;
     case kTfLiteInt32:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, int32_t);
-      }
+      TF_LITE_STRIDED_SLICE(int32_t);
       break;
     case kTfLiteInt64:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, int64_t);
-      }
+      TF_LITE_STRIDED_SLICE(int64_t);
       break;
     case kTfLiteUInt8:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, uint8_t);
-      }
+      TF_LITE_STRIDED_SLICE(uint8_t);
       break;
     case kTfLiteInt8:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, int8_t);
-      }
+      TF_LITE_STRIDED_SLICE(int8_t);
       break;
     case kTfLiteInt16:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, int16_t);
-      }
+      TF_LITE_STRIDED_SLICE(int16_t);
      break;
     case kTfLiteBool:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, bool);
-      }
+      TF_LITE_STRIDED_SLICE(bool);
      break;
     case kTfLiteString:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, string);
-      }
+      TF_LITE_STRIDED_SLICE(string);
       break;
     default:
       TF_LITE_KERNEL_LOG(context,
@@ -257,7 +250,10 @@ TfLiteRegistration* Register_STRIDED_SLICE_REF() {
 }
 
 TfLiteRegistration* Register_STRIDED_SLICE() {
-  return Register_STRIDED_SLICE_REF();
+  static TfLiteRegistration r = {
+      nullptr, nullptr, strided_slice::Prepare,
+      strided_slice::Eval<strided_slice::kGenericOptimized>};
+  return &r;
 }
 
 }  // namespace builtin
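The registration change above makes the default STRIDED_SLICE entry point use the kGenericOptimized instantiation of Eval, while Register_STRIDED_SLICE_REF keeps the reference one. A reduced sketch of the dispatch pattern (illustrative, not the full TfLiteRegistration types):

enum KernelType { kReference, kGenericOptimized };

// kernel_type is a non-type template parameter, so the branch below is
// resolved at compile time; each registration instantiates one variant.
template <KernelType kernel_type>
void Eval() {
  if (kernel_type == kGenericOptimized) {
    // dispatch to optimized_ops::StridedSlice
  } else {
    // dispatch to reference_ops::StridedSlice
  }
}

int main() {
  Eval<kGenericOptimized>();  // what Register_STRIDED_SLICE now selects
  Eval<kReference>();         // what Register_STRIDED_SLICE_REF keeps
  return 0;
}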