[tf.lite] Optimize strided_slice when inner stride == 1
PiperOrigin-RevId: 350407259
Change-Id: I8265097ad85aab0b0ff998f9dd465eb2ecf20183
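In brief: a strided slice whose innermost stride is 1 selects one contiguous run of elements for each combination of outer indices, so the innermost loop can be replaced by a single bulk copy. A minimal standalone sketch of that idea (the function name and the 2-D row-major layout here are illustrative only, not the TFLite kernel):

#include <vector>

// Illustrative 2-D slice: rows [row_begin, row_end) taken with row_stride,
// columns [col_begin, col_end) taken with an inner (column) stride of 1.
// Because the inner stride is 1, each selected row is one contiguous run in
// the row-major input and can be appended to the output in a single copy.
template <typename T>
void SliceInnerStrideOne(const std::vector<T>& input, int cols, int row_begin,
                         int row_end, int row_stride, int col_begin,
                         int col_end, std::vector<T>* output) {
  const int len = col_end - col_begin;  // length of each contiguous run
  for (int r = row_begin; r < row_end; r += row_stride) {
    const T* src = input.data() + r * cols + col_begin;
    output->insert(output->end(), src, src + len);
  }
}

When the innermost stride is greater than 1 the selected elements are no longer contiguous, which is why the new kernel below falls back to the reference implementation in that case.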
parent 7d841e13c4
commit 335857da1b

Diff (all changed files are under tensorflow/lite/kernels):
tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h
@@ -71,7 +71,6 @@ using reference_ops::ReluX;
 using reference_ops::Select;
 using reference_ops::SpaceToBatchND;
 using reference_ops::Split;
-using reference_ops::StridedSlice;
 using reference_ops::TensorFlowSplit;
 
 static constexpr int kDepthwiseReverseShift = -1;
@@ -4949,6 +4948,23 @@ void Transpose(const T* input, const Dims<4>& input_dims, T* output,
             output);
 }
 
+template <typename T>
+inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
+                         int begin_mask, int end_mask, int shrink_axis_mask,
+                         const std::vector<int>& start_indices,
+                         const std::vector<int>& stop_indices,
+                         const std::vector<int>& strides, T* output_data,
+                         const Dims<4>& output_dims) {
+  TFLITE_DCHECK_EQ(start_indices.size(), 4);
+  auto op_params = strided_slice::BuildStridedSliceParams(
+      begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
+      strides);
+  reference_ops::StridedSliceReverseIndices(&op_params);
+
+  StridedSlice(op_params, DimsToShape(input_dims), input_data,
+               DimsToShape(output_dims), output_data);
+}
+
 }  // namespace optimized_ops
 }  // namespace tflite
 
 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -104,7 +104,6 @@ using reference_ops::Round;
 using reference_ops::Select;
 using reference_ops::SpaceToBatchND;
 using reference_ops::Split;
-using reference_ops::StridedSlice;
 using reference_ops::Sub16;
 
 // TODO(b/80247582) Remove this constant.
@@ -5449,6 +5448,106 @@ inline void Slice(const tflite::SliceParams& op_params,
   return Slice(op_params, input_shape, output_shape, &writer);
 }
 
+template <typename T>
+inline void StridedSlice(const tflite::StridedSliceParams& op_params,
+                         const RuntimeShape& unextended_input_shape,
+                         const RuntimeShape& unextended_output_shape,
+                         SequentialTensorWriter<T>* writer) {
+  using strided_slice::LoopCondition;
+  using strided_slice::StartForAxis;
+  using strided_slice::StopForAxis;
+
+  // We only have an optimized implementation for the case where the inner-most
+  // stride == 1. For all other cases, fall back to the reference impl.
+  if ((op_params.strides_count <= 0) ||
+      (op_params.strides[op_params.strides_count - 1] != 1)) {
+    reference_ops::StridedSlice(op_params, unextended_input_shape,
+                                unextended_output_shape, writer);
+    return;
+  }
+
+  ruy::profiler::ScopeLabel label("StridedSliceInnerStrideOne");
+
+  // Note that the output_shape is not used herein.
+  tflite::StridedSliceParams params_copy = op_params;
+
+  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 5);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 5);
+  const RuntimeShape input_shape =
+      RuntimeShape::ExtendedShape(5, unextended_input_shape);
+  const RuntimeShape output_shape =
+      RuntimeShape::ExtendedShape(5, unextended_output_shape);
+
+  // Reverse and pad to 5 dimensions because that is what the runtime code
+  // requires (ie. all shapes must be 5D and are given backwards).
+  strided_slice::StridedSlicePadIndices(&params_copy, 5);
+  TFLITE_DCHECK_EQ(params_copy.strides[4], 1);
+
+  const int start_0 = StartForAxis(params_copy, input_shape, 0);
+  const int stop_0 = StopForAxis(params_copy, input_shape, 0, start_0);
+  const int start_1 = StartForAxis(params_copy, input_shape, 1);
+  const int stop_1 = StopForAxis(params_copy, input_shape, 1, start_1);
+  const int start_2 = StartForAxis(params_copy, input_shape, 2);
+  const int stop_2 = StopForAxis(params_copy, input_shape, 2, start_2);
+  const int start_3 = StartForAxis(params_copy, input_shape, 3);
+  const int stop_3 = StopForAxis(params_copy, input_shape, 3, start_3);
+  const int start_4 = StartForAxis(params_copy, input_shape, 4);
+  const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4);
+
+  for (int offset_0 = start_0 * input_shape.Dims(1),
+           end_0 = stop_0 * input_shape.Dims(1),
+           step_0 = params_copy.strides[0] * input_shape.Dims(1);
+       !LoopCondition(offset_0, end_0, params_copy.strides[0]);
+       offset_0 += step_0) {
+    for (int offset_1 = (offset_0 + start_1) * input_shape.Dims(2),
+             end_1 = (offset_0 + stop_1) * input_shape.Dims(2),
+             step_1 = params_copy.strides[1] * input_shape.Dims(2);
+         !LoopCondition(offset_1, end_1, params_copy.strides[1]);
+         offset_1 += step_1) {
+      for (int offset_2 = (offset_1 + start_2) * input_shape.Dims(3),
+               end_2 = (offset_1 + stop_2) * input_shape.Dims(3),
+               step_2 = params_copy.strides[2] * input_shape.Dims(3);
+           !LoopCondition(offset_2, end_2, params_copy.strides[2]);
+           offset_2 += step_2) {
+        for (int offset_3 = (offset_2 + start_3) * input_shape.Dims(4),
+                 end_3 = (offset_2 + stop_3) * input_shape.Dims(4),
+                 step_3 = params_copy.strides[3] * input_shape.Dims(4);
+             !LoopCondition(offset_3, end_3, params_copy.strides[3]);
+             offset_3 += step_3) {
+          // Note: We've already validated that the inner-most stride is 1, so
+          // we can safely write the full inner sequence.
+          const int len = stop_4 - start_4;
+          if (len > 0) {
+            writer->WriteN(offset_3 + start_4, len);
+          }
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+inline void StridedSlice(const tflite::StridedSliceParams& op_params,
+                         const RuntimeShape& unextended_input_shape,
+                         const T* input_data,
+                         const RuntimeShape& unextended_output_shape,
+                         T* output_data) {
+  SequentialTensorWriter<T> writer(input_data, output_data);
+  StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
+                  &writer);
+}
+
+template <typename T>
+inline void StridedSlice(const tflite::StridedSliceParams& op_params,
+                         const RuntimeShape& unextended_input_shape,
+                         const TfLiteTensor* input,
+                         const RuntimeShape& unextended_output_shape,
+                         TfLiteTensor* output) {
+  SequentialTensorWriter<T> writer(input, output);
+  StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
+                  &writer);
+}
+
 template <typename T>
 void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
              const T* input2_data, const RuntimeShape& output_shape,
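The SequentialTensorWriter<T> used above comes from portable_tensor.h and is not shown in this diff. As a rough, hypothetical stand-in (names and details assumed purely for illustration), a writer whose WriteN(position, len) appends a run of len input elements to the output would behave like the class below, which is why the innermost dimension above reduces to one bulk write per contiguous run:

#include <algorithm>

// Hypothetical, simplified stand-in for a sequential tensor writer over flat
// buffers. WriteN(position, len) copies input[position, position + len) to the
// next free slots of the output and advances the output cursor.
template <typename T>
class FlatRunWriter {
 public:
  FlatRunWriter(const T* input, T* output) : input_(input), output_(output) {}
  void WriteN(int position, int len) {
    std::copy(input_ + position, input_ + position + len, output_);
    output_ += len;
  }

 private:
  const T* input_;
  T* output_;
};

This is only a sketch of the access pattern; the committed kernel relies on the existing SequentialTensorWriter<T> interface.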
tensorflow/lite/kernels/internal/reference/strided_slice.h
@@ -15,6 +15,7 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_STRIDED_SLICE_H_
 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_STRIDED_SLICE_H_
 
+#include "ruy/profiler/instrumentation.h"  // from @ruy
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
 #include "tensorflow/lite/kernels/internal/portable_tensor.h"
@@ -33,6 +34,9 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
   using strided_slice::LoopCondition;
   using strided_slice::StartForAxis;
   using strided_slice::StopForAxis;
+
+  ruy::profiler::ScopeLabel label("StridedSlice");
+
   // Note that the output_shape is not used herein.
   tflite::StridedSliceParams params_copy = op_params;
 
tensorflow/lite/kernels/slice.cc
@@ -190,7 +190,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   // The dimensions in the kernel used to be in reverse-order, and TFLite
   // arranged the begins and sizes vectors accordingly. This macro incorporates
   // the needed reversing.
-#define TF_LITE_SLICE(data_type, kernel_type)                 \
+#define TF_LITE_SLICE(data_type)                              \
   {                                                           \
     TF_LITE_ENSURE_EQ(context, begins.size(), kMaxDim);       \
     TF_LITE_ENSURE_EQ(context, sizes.size(), kMaxDim);        \
@@ -213,28 +213,28 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 
   switch (input->type) {
     case kTfLiteFloat32:
-      TF_LITE_SLICE(float, kernel_type);
+      TF_LITE_SLICE(float);
       break;
     case kTfLiteInt32:
-      TF_LITE_SLICE(int32_t, kernel_type);
+      TF_LITE_SLICE(int32_t);
       break;
     case kTfLiteInt64:
-      TF_LITE_SLICE(int64_t, kernel_type);
+      TF_LITE_SLICE(int64_t);
       break;
     case kTfLiteInt8:
-      TF_LITE_SLICE(int8_t, kernel_type);
+      TF_LITE_SLICE(int8_t);
       break;
     case kTfLiteInt16:
-      TF_LITE_SLICE(int16_t, kernel_type);
+      TF_LITE_SLICE(int16_t);
       break;
     case kTfLiteUInt8:
-      TF_LITE_SLICE(uint8_t, kernel_type);
+      TF_LITE_SLICE(uint8_t);
       break;
     case kTfLiteBool:
-      TF_LITE_SLICE(bool, kernel_type);
+      TF_LITE_SLICE(bool);
       break;
     case kTfLiteString:
-      TF_LITE_SLICE(string, kernel_type);
+      TF_LITE_SLICE(string);
       break;
     default:
       context->ReportError(
tensorflow/lite/kernels/strided_slice.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
@@ -37,7 +38,7 @@ namespace strided_slice {
 
 enum KernelType {
   kReference,
-  // TODO(b/175642009): add kGenericOptimized
+  kGenericOptimized,
 };
 
 constexpr int kInputTensor = 0;
@@ -190,51 +191,43 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   }
   StridedSliceParams op_params = BuildStridedSliceParams(&op_context);
 
-#define TF_LITE_STRIDED_SLICE(kernel_type, data_type)                    \
-  kernel_type::StridedSlice<data_type>(                                  \
-      op_params, GetTensorShape(op_context.input), op_context.input,     \
-      GetTensorShape(op_context.output), op_context.output)
+#define TF_LITE_STRIDED_SLICE(data_type)                                 \
+  {                                                                      \
+    if (kernel_type == kGenericOptimized) {                              \
+      optimized_ops::StridedSlice<data_type>(                            \
+          op_params, GetTensorShape(op_context.input), op_context.input, \
+          GetTensorShape(op_context.output), op_context.output);         \
+    } else {                                                             \
+      reference_ops::StridedSlice<data_type>(                            \
+          op_params, GetTensorShape(op_context.input), op_context.input, \
+          GetTensorShape(op_context.output), op_context.output);         \
+    }                                                                    \
+  }
 
   switch (op_context.input->type) {
     case kTfLiteFloat32:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, float);
-      }
+      TF_LITE_STRIDED_SLICE(float);
       break;
     case kTfLiteInt32:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, int32_t);
-      }
+      TF_LITE_STRIDED_SLICE(int32_t);
       break;
     case kTfLiteInt64:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, int64_t);
-      }
+      TF_LITE_STRIDED_SLICE(int64_t);
       break;
     case kTfLiteUInt8:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, uint8_t);
-      }
+      TF_LITE_STRIDED_SLICE(uint8_t);
      break;
     case kTfLiteInt8:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, int8_t);
-      }
+      TF_LITE_STRIDED_SLICE(int8_t);
       break;
     case kTfLiteInt16:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, int16_t);
-      }
+      TF_LITE_STRIDED_SLICE(int16_t);
       break;
     case kTfLiteBool:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, bool);
-      }
+      TF_LITE_STRIDED_SLICE(bool);
       break;
     case kTfLiteString:
-      if (kernel_type == kReference) {
-        TF_LITE_STRIDED_SLICE(reference_ops, string);
-      }
+      TF_LITE_STRIDED_SLICE(string);
       break;
     default:
       TF_LITE_KERNEL_LOG(context,
@@ -257,7 +250,10 @@ TfLiteRegistration* Register_STRIDED_SLICE_REF() {
 }
 
 TfLiteRegistration* Register_STRIDED_SLICE() {
-  return Register_STRIDED_SLICE_REF();
+  static TfLiteRegistration r = {
+      nullptr, nullptr, strided_slice::Prepare,
+      strided_slice::Eval<strided_slice::kGenericOptimized>};
+  return &r;
 }
 
 }  // namespace builtin