[tf.lite] Optimize strided_slice when inner stride == 1

PiperOrigin-RevId: 350407259
Change-Id: I8265097ad85aab0b0ff998f9dd465eb2ecf20183
Jared Duke, 2021-01-06 12:37:19 -08:00 (committed by TensorFlower Gardener)
commit 335857da1b, parent 7d841e13c4
5 changed files with 156 additions and 41 deletions
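
What the change does, in short: optimized_ops::StridedSlice gets a real implementation that, whenever the inner-most stride is 1, writes each innermost run as one contiguous block instead of one element at a time, and falls back to reference_ops::StridedSlice for every other stride pattern. Below is a minimal standalone sketch of the trick in 2-D rather than the kernel's padded 5-D; SliceRows and its arguments are invented for illustration and handle only positive strides:

#include <vector>

// Hypothetical 2-D illustration: when the column stride is 1, the inner
// loop collapses into one contiguous copy per selected row.
void SliceRows(const float* input, int cols, int start_row, int stop_row,
               int row_stride, int start_col, int stop_col, int col_stride,
               std::vector<float>* output) {
  for (int r = start_row; r < stop_row; r += row_stride) {
    const float* row = input + r * cols;
    if (col_stride == 1) {
      // Fast path: append the whole run [start_col, stop_col) at once.
      output->insert(output->end(), row + start_col, row + stop_col);
    } else {
      // General path: element-by-element gather.
      for (int c = start_col; c < stop_col; c += col_stride) {
        output->push_back(row[c]);
      }
    }
  }
}

In the kernel itself the same run-length write goes through writer->WriteN(offset_3 + start_4, stop_4 - start_4) on a SequentialTensorWriter, so the same code path can serve string tensors as well.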

tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h

@@ -71,7 +71,6 @@ using reference_ops::ReluX;
 using reference_ops::Select;
 using reference_ops::SpaceToBatchND;
 using reference_ops::Split;
-using reference_ops::StridedSlice;
 using reference_ops::TensorFlowSplit;

 static constexpr int kDepthwiseReverseShift = -1;
@@ -4949,6 +4948,23 @@ void Transpose(const T* input, const Dims<4>& input_dims, T* output,
             output);
 }

+template <typename T>
+inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
+                         int begin_mask, int end_mask, int shrink_axis_mask,
+                         const std::vector<int>& start_indices,
+                         const std::vector<int>& stop_indices,
+                         const std::vector<int>& strides, T* output_data,
+                         const Dims<4>& output_dims) {
+  TFLITE_DCHECK_EQ(start_indices.size(), 4);
+  auto op_params = strided_slice::BuildStridedSliceParams(
+      begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
+      strides);
+  reference_ops::StridedSliceReverseIndices(&op_params);
+
+  StridedSlice(op_params, DimsToShape(input_dims), input_data,
+               DimsToShape(output_dims), output_data);
+}
+
 }  // namespace optimized_ops
 }  // namespace tflite

 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
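
Two things happen in this legacy header: the using reference_ops::StridedSlice; alias goes away, and a forwarding wrapper takes its place that builds StridedSliceParams and calls reference_ops::StridedSliceReverseIndices before dispatching, since the legacy Dims<4> API supplies indices innermost-first while the new kernels expect them outermost-first. A toy illustration of just that reversal step (the values are made up):

#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical illustration: legacy-order start indices (innermost axis
// first) are reversed so the outermost axis comes first, matching what
// StridedSliceParams-based kernels expect.
int main() {
  std::vector<int> legacy_starts = {0, 2, 0, 1};  // innermost axis first
  std::vector<int> starts = legacy_starts;
  std::reverse(starts.begin(), starts.end());     // outermost axis first
  for (int s : starts) std::printf("%d ", s);     // prints: 1 0 2 0
  return 0;
}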

tensorflow/lite/kernels/internal/optimized/optimized_ops.h

@@ -104,7 +104,6 @@ using reference_ops::Round;
 using reference_ops::Select;
 using reference_ops::SpaceToBatchND;
 using reference_ops::Split;
-using reference_ops::StridedSlice;
 using reference_ops::Sub16;

 // TODO(b/80247582) Remove this constant.
@@ -5449,6 +5448,106 @@ inline void Slice(const tflite::SliceParams& op_params,
   return Slice(op_params, input_shape, output_shape, &writer);
 }

+template <typename T>
+inline void StridedSlice(const tflite::StridedSliceParams& op_params,
+                         const RuntimeShape& unextended_input_shape,
+                         const RuntimeShape& unextended_output_shape,
+                         SequentialTensorWriter<T>* writer) {
+  using strided_slice::LoopCondition;
+  using strided_slice::StartForAxis;
+  using strided_slice::StopForAxis;
+
+  // We only have an optimized implementation for the case where the inner-most
+  // stride == 1. For all other cases, fall back to the reference impl.
+  if ((op_params.strides_count <= 0) ||
+      (op_params.strides[op_params.strides_count - 1] != 1)) {
+    reference_ops::StridedSlice(op_params, unextended_input_shape,
+                                unextended_output_shape, writer);
+    return;
+  }
+
+  ruy::profiler::ScopeLabel label("StridedSliceInnerStrideOne");
+
+  // Note that the output_shape is not used herein.
+  tflite::StridedSliceParams params_copy = op_params;
+
+  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 5);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 5);
+  const RuntimeShape input_shape =
+      RuntimeShape::ExtendedShape(5, unextended_input_shape);
+  const RuntimeShape output_shape =
+      RuntimeShape::ExtendedShape(5, unextended_output_shape);
+
+  // Reverse and pad to 5 dimensions because that is what the runtime code
+  // requires (ie. all shapes must be 5D and are given backwards).
+  strided_slice::StridedSlicePadIndices(&params_copy, 5);
+  TFLITE_DCHECK_EQ(params_copy.strides[4], 1);
+
+  const int start_0 = StartForAxis(params_copy, input_shape, 0);
+  const int stop_0 = StopForAxis(params_copy, input_shape, 0, start_0);
+  const int start_1 = StartForAxis(params_copy, input_shape, 1);
+  const int stop_1 = StopForAxis(params_copy, input_shape, 1, start_1);
+  const int start_2 = StartForAxis(params_copy, input_shape, 2);
+  const int stop_2 = StopForAxis(params_copy, input_shape, 2, start_2);
+  const int start_3 = StartForAxis(params_copy, input_shape, 3);
+  const int stop_3 = StopForAxis(params_copy, input_shape, 3, start_3);
+  const int start_4 = StartForAxis(params_copy, input_shape, 4);
+  const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4);
+
+  for (int offset_0 = start_0 * input_shape.Dims(1),
+           end_0 = stop_0 * input_shape.Dims(1),
+           step_0 = params_copy.strides[0] * input_shape.Dims(1);
+       !LoopCondition(offset_0, end_0, params_copy.strides[0]);
+       offset_0 += step_0) {
+    for (int offset_1 = (offset_0 + start_1) * input_shape.Dims(2),
+             end_1 = (offset_0 + stop_1) * input_shape.Dims(2),
+             step_1 = params_copy.strides[1] * input_shape.Dims(2);
+         !LoopCondition(offset_1, end_1, params_copy.strides[1]);
+         offset_1 += step_1) {
+      for (int offset_2 = (offset_1 + start_2) * input_shape.Dims(3),
+               end_2 = (offset_1 + stop_2) * input_shape.Dims(3),
+               step_2 = params_copy.strides[2] * input_shape.Dims(3);
+           !LoopCondition(offset_2, end_2, params_copy.strides[2]);
+           offset_2 += step_2) {
+        for (int offset_3 = (offset_2 + start_3) * input_shape.Dims(4),
+                 end_3 = (offset_2 + stop_3) * input_shape.Dims(4),
+                 step_3 = params_copy.strides[3] * input_shape.Dims(4);
+             !LoopCondition(offset_3, end_3, params_copy.strides[3]);
+             offset_3 += step_3) {
+          // Note: We've already validated that the inner-most stride is 1, so
+          // we can safely write the full inner sequence.
+          const int len = stop_4 - start_4;
+          if (len > 0) {
+            writer->WriteN(offset_3 + start_4, len);
+          }
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+inline void StridedSlice(const tflite::StridedSliceParams& op_params,
+                         const RuntimeShape& unextended_input_shape,
+                         const T* input_data,
+                         const RuntimeShape& unextended_output_shape,
+                         T* output_data) {
+  SequentialTensorWriter<T> writer(input_data, output_data);
+  StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
+                  &writer);
+}
+
+template <typename T>
+inline void StridedSlice(const tflite::StridedSliceParams& op_params,
+                         const RuntimeShape& unextended_input_shape,
+                         const TfLiteTensor* input,
+                         const RuntimeShape& unextended_output_shape,
+                         TfLiteTensor* output) {
+  SequentialTensorWriter<T> writer(input, output);
+  StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
+                  &writer);
+}
+
 template <typename T>
 void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
              const T* input2_data, const RuntimeShape& output_shape,
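
One detail worth calling out in the hunk above: the loops never materialize a full index tuple. Each level folds the accumulated offset into the next dimension via offset_next = (offset + start_next) * Dims(next + 1), so by the innermost level offset_3 + start_4 is already the flat element index. A tiny self-contained check of that folding identity in 3-D (all numbers invented):

#include <cassert>

// Hypothetical 3-D check: folding offsets dimension by dimension yields the
// usual row-major flat index ((i * D1) + j) * D2 + k.
int main() {
  const int D1 = 3, D2 = 4;       // sizes of the two inner dimensions
  const int i = 2, j = 1, k = 3;  // an arbitrary coordinate
  int offset_0 = i * D1;               // fold dimension 0
  int offset_1 = (offset_0 + j) * D2;  // fold dimension 1
  int flat = offset_1 + k;             // add the innermost index
  assert(flat == (i * D1 + j) * D2 + k);
  return 0;
}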

tensorflow/lite/kernels/internal/reference/strided_slice.h

@@ -15,6 +15,7 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_STRIDED_SLICE_H_
 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_STRIDED_SLICE_H_

+#include "ruy/profiler/instrumentation.h"  // from @ruy
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
 #include "tensorflow/lite/kernels/internal/portable_tensor.h"
@@ -33,6 +34,9 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
   using strided_slice::LoopCondition;
   using strided_slice::StartForAxis;
   using strided_slice::StopForAxis;
+
+  ruy::profiler::ScopeLabel label("StridedSlice");
+
   // Note that the output_shape is not used herein.
   tflite::StridedSliceParams params_copy = op_params;
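
The reference-kernel change above is instrumentation only: a ruy::profiler::ScopeLabel so the op is attributed correctly in ruy's sampling profiler traces. Usage is just a scoped RAII label; a sketch, assuming ruy's profiler is available and enabled in the build (function names here are invented):

#include "ruy/profiler/instrumentation.h"  // from @ruy

void MyKernel() {
  // Labels nest; samples are attributed to the innermost live scope.
  ruy::profiler::ScopeLabel label("MyKernel");
  {
    ruy::profiler::ScopeLabel inner("MyKernel/InnerLoop");
    // ... hot loop ...
  }
}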

tensorflow/lite/kernels/slice.cc

@@ -190,7 +190,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 // The dimensions in the kernel used to be in reverse-order, and TFLite
 // arranged the begins and sizes vectors accordingly. This macro incorporates
 // the needed reversing.
-#define TF_LITE_SLICE(data_type, kernel_type)               \
+#define TF_LITE_SLICE(data_type)                            \
   {                                                         \
     TF_LITE_ENSURE_EQ(context, begins.size(), kMaxDim);     \
     TF_LITE_ENSURE_EQ(context, sizes.size(), kMaxDim);      \
@@ -213,28 +213,28 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   switch (input->type) {
     case kTfLiteFloat32:
-      TF_LITE_SLICE(float, kernel_type);
+      TF_LITE_SLICE(float);
       break;
     case kTfLiteInt32:
-      TF_LITE_SLICE(int32_t, kernel_type);
+      TF_LITE_SLICE(int32_t);
       break;
     case kTfLiteInt64:
-      TF_LITE_SLICE(int64_t, kernel_type);
+      TF_LITE_SLICE(int64_t);
       break;
     case kTfLiteInt8:
-      TF_LITE_SLICE(int8_t, kernel_type);
+      TF_LITE_SLICE(int8_t);
       break;
     case kTfLiteInt16:
-      TF_LITE_SLICE(int16_t, kernel_type);
+      TF_LITE_SLICE(int16_t);
       break;
     case kTfLiteUInt8:
-      TF_LITE_SLICE(uint8_t, kernel_type);
+      TF_LITE_SLICE(uint8_t);
       break;
     case kTfLiteBool:
-      TF_LITE_SLICE(bool, kernel_type);
+      TF_LITE_SLICE(bool);
       break;
     case kTfLiteString:
-      TF_LITE_SLICE(string, kernel_type);
+      TF_LITE_SLICE(string);
       break;
     default:
       context->ReportError(
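
The slice.cc hunks are a drive-by cleanup: TF_LITE_SLICE never used its kernel_type argument, so the parameter is dropped and every expansion site updated. The underlying pattern, a type-dispatch macro stamped out inside a switch on the tensor's type tag, looks like this in miniature (a hedged sketch with made-up names, not the TFLite macro):

#include <cstdio>

enum class Type { kFloat32, kInt32 };

// Hypothetical miniature of the dispatch pattern: one macro body per element
// type, selected by a runtime switch on the type tag.
#define DO_SLICE(data_type) \
  { std::printf("slicing as %s\n", #data_type); }

void Dispatch(Type t) {
  switch (t) {
    case Type::kFloat32:
      DO_SLICE(float);
      break;
    case Type::kInt32:
      DO_SLICE(int32_t);
      break;
  }
}

int main() { Dispatch(Type::kFloat32); }  // prints: slicing as float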

tensorflow/lite/kernels/strided_slice.cc

@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
@@ -37,7 +38,7 @@ namespace strided_slice {

 enum KernelType {
   kReference,
-  // TODO(b/175642009): add kGenericOptimized
+  kGenericOptimized,
 };

 constexpr int kInputTensor = 0;
@@ -190,51 +191,43 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   }

   StridedSliceParams op_params = BuildStridedSliceParams(&op_context);

-#define TF_LITE_STRIDED_SLICE(kernel_type, data_type)                    \
-  kernel_type::StridedSlice<data_type>(                                  \
-      op_params, GetTensorShape(op_context.input), op_context.input,     \
-      GetTensorShape(op_context.output), op_context.output)
+#define TF_LITE_STRIDED_SLICE(data_type)                                  \
+  {                                                                       \
+    if (kernel_type == kGenericOptimized) {                               \
+      optimized_ops::StridedSlice<data_type>(                             \
+          op_params, GetTensorShape(op_context.input), op_context.input,  \
+          GetTensorShape(op_context.output), op_context.output);          \
+    } else {                                                              \
+      reference_ops::StridedSlice<data_type>(                             \
+          op_params, GetTensorShape(op_context.input), op_context.input,  \
+          GetTensorShape(op_context.output), op_context.output);          \
+    }                                                                     \
+  }

   switch (op_context.input->type) {
     case kTfLiteFloat32:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, float);
-      }
+      TF_LITE_STRIDED_SLICE(float);
       break;
     case kTfLiteInt32:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, int32_t);
-      }
+      TF_LITE_STRIDED_SLICE(int32_t);
       break;
     case kTfLiteInt64:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, int64_t);
-      }
+      TF_LITE_STRIDED_SLICE(int64_t);
       break;
     case kTfLiteUInt8:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, uint8_t);
-      }
+      TF_LITE_STRIDED_SLICE(uint8_t);
       break;
     case kTfLiteInt8:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, int8_t);
-      }
+      TF_LITE_STRIDED_SLICE(int8_t);
       break;
     case kTfLiteInt16:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, int16_t);
-      }
+      TF_LITE_STRIDED_SLICE(int16_t);
       break;
     case kTfLiteBool:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, bool);
-      }
+      TF_LITE_STRIDED_SLICE(bool);
       break;
     case kTfLiteString:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, string);
-      }
+      TF_LITE_STRIDED_SLICE(string);
       break;
     default:
       TF_LITE_KERNEL_LOG(context,
@@ -257,7 +250,10 @@ TfLiteRegistration* Register_STRIDED_SLICE_REF() {
 }

 TfLiteRegistration* Register_STRIDED_SLICE() {
-  return Register_STRIDED_SLICE_REF();
+  static TfLiteRegistration r = {
+      nullptr, nullptr, strided_slice::Prepare,
+      strided_slice::Eval<strided_slice::kGenericOptimized>};
+  return &r;
 }

 }  // namespace builtin
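
Finally, Register_STRIDED_SLICE stops aliasing the reference registration and instead builds one whose Eval is instantiated with kGenericOptimized. Because kernel_type is a template parameter, the if (kernel_type == kGenericOptimized) inside the macro is resolved at compile time and the dead branch is eliminated. A minimal sketch of that dispatch style (names here are invented):

#include <cstdio>

enum KernelType { kReference, kGenericOptimized };

// Hypothetical sketch: the kernel flavor is a template parameter, so each
// instantiation compiles down to a single, branch-free implementation.
template <KernelType kernel_type>
void Eval() {
  if (kernel_type == kGenericOptimized) {
    std::printf("optimized path\n");
  } else {
    std::printf("reference path\n");
  }
}

int main() {
  Eval<kReference>();         // reference path
  Eval<kGenericOptimized>();  // optimized path
  return 0;
}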