Optimize the broadcast add float path by reusing the broadcast fivefold logic.
PiperOrigin-RevId: 274325783
parent 7b42c3b12a
commit 895659f67b
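
The fivefold pattern reused here collapses a 4-D broadcast into five run lengths y0..y4 (carried in ArithmeticParams::broadcast_shape, filled in by ProcessBroadcastShapes) such that input1 has flat size y0 * y1 * y2 * y4 and input2 has flat size y0 * y2 * y3 * y4; the add then walks contiguous y4-sized slices in nested loops instead of recomputing a 4-D index per element. The following is a minimal standalone sketch of that loop structure, assuming hand-picked shapes [2, 3, 1, 8] and [2, 1, 5, 8] and the matching y0..y4 values; it is an illustration, not the TFLite implementation, and it checks itself against a naive 4-D broadcast add.

    // Standalone sketch of the fivefold broadcast-add loop structure. The shapes
    // and the y0..y4 values below are hand-picked for the example; the real code
    // gets them from ProcessBroadcastShapes via ArithmeticParams::broadcast_shape.
    #include <cassert>
    #include <cstdio>
    #include <vector>

    int main() {
      // input1: NHWC [2, 3, 1, 8], input2: [2, 1, 5, 8], output: [2, 3, 5, 8].
      // Fivefold decomposition: y0=2, y1=3, y2=1, y3=5, y4=8, so that
      // input1 flat size == y0*y1*y2*y4 and input2 flat size == y0*y2*y3*y4.
      const int y0 = 2, y1 = 3, y2 = 1, y3 = 5, y4 = 8;
      std::vector<float> input1(y0 * y1 * y2 * y4);
      std::vector<float> input2(y0 * y2 * y3 * y4);
      std::vector<float> output(y0 * y1 * y2 * y3 * y4, 0.0f);
      for (size_t i = 0; i < input1.size(); ++i) input1[i] = 0.001f * static_cast<float>(i);
      for (size_t i = 0; i < input2.size(); ++i) input2[i] = 1.0f + 0.01f * static_cast<float>(i);

      // Fivefold nested loops, mirroring the structure added in this commit: the
      // innermost step adds two contiguous y4-sized slices, input1 advances once
      // per i2 iteration, and input2 rewinds to a per-i0 reset point on every i1.
      float* out_ptr = output.data();
      const float* in1_ptr = input1.data();
      const float* in2_reset = input2.data();
      for (int i0 = 0; i0 < y0; ++i0) {
        const float* in2_ptr = nullptr;
        for (int i1 = 0; i1 < y1; ++i1) {
          in2_ptr = in2_reset;
          for (int i2 = 0; i2 < y2; ++i2) {
            for (int i3 = 0; i3 < y3; ++i3) {
              for (int i = 0; i < y4; ++i) out_ptr[i] = in1_ptr[i] + in2_ptr[i];
              in2_ptr += y4;
              out_ptr += y4;
            }
            in1_ptr += y4;  // this input1 slice was reused y3 times
          }
        }
        in2_reset = in2_ptr;  // this input2 block was reused y1 times
      }

      // Check against the naive 4-D broadcast:
      // output[n][h][w][c] = input1[n][h][0][c] + input2[n][0][w][c].
      for (int n = 0; n < 2; ++n)
        for (int h = 0; h < 3; ++h)
          for (int w = 0; w < 5; ++w)
            for (int c = 0; c < 8; ++c) {
              const float expected =
                  input1[(n * 3 + h) * 8 + c] + input2[(n * 5 + w) * 8 + c];
              const float got = output[((n * 3 + h) * 5 + w) * 8 + c];
              assert(expected == got);
            }
      std::printf("fivefold broadcast add matches naive broadcast add\n");
      return 0;
    }
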
@@ -42,8 +42,6 @@ constexpr int kInputTensor2 = 1;
 constexpr int kOutputTensor = 0;

 struct OpData {
-  bool requires_broadcast;
-
   // These fields are used in both the general 8-bit -> 8bit quantized path,
   // and the special 16-bit -> 16bit quantized path
   int input1_shift;
@@ -64,7 +62,6 @@ struct OpData {

 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   auto* data = new OpData;
-  data->requires_broadcast = false;
   return data;
 }

@@ -86,10 +83,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
   output->type = input2->type;

-  data->requires_broadcast = !HaveSameShapes(input1, input2);
+  const bool requires_broadcast = !HaveSameShapes(input1, input2);

   TfLiteIntArray* output_size = nullptr;
-  if (data->requires_broadcast) {
+  if (requires_broadcast) {
     TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
                                     context, input1, input2, &output_size));
   } else {
@@ -179,11 +176,15 @@ template <KernelType kernel_type>
 void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
              const OpData* data, const TfLiteTensor* input1,
              const TfLiteTensor* input2, TfLiteTensor* output) {
+  tflite::ArithmeticParams op_params;
+  // requires_flat_size_broadcast is used for BroadcastAdd4DSlow.
+  const bool requires_flat_size_broadcast = !HaveSameShapes(input1, input2);
+  const bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
+      GetTensorShape(input1), GetTensorShape(input2), &op_params);
 #define TF_LITE_ADD(type, opname, data_type)                             \
   data_type output_activation_min, output_activation_max;                \
   CalculateActivationRange(params->activation, &output_activation_min,   \
                            &output_activation_max);                      \
-  tflite::ArithmeticParams op_params;                                    \
   SetActivationParams(output_activation_min, output_activation_max,      \
                       &op_params);                                       \
   type::opname(op_params, GetTensorShape(input1),                        \
@@ -192,13 +193,13 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
                GetTensorData<data_type>(output))
   if (output->type == kTfLiteInt32) {
     if (kernel_type == kReference) {
-      if (data->requires_broadcast) {
+      if (requires_flat_size_broadcast) {
         TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, int32_t);
       } else {
         TF_LITE_ADD(reference_ops, Add, int32_t);
       }
     } else {
-      if (data->requires_broadcast) {
+      if (requires_flat_size_broadcast) {
         TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow, int32_t);
       } else {
         TF_LITE_ADD(optimized_ops, Add, int32_t);
@@ -206,13 +207,15 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
     }
   } else if (output->type == kTfLiteFloat32) {
     if (kernel_type == kReference) {
-      if (data->requires_broadcast) {
+      if (requires_flat_size_broadcast) {
         TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, float);
       } else {
         TF_LITE_ADD(reference_ops, Add, float);
       }
     } else {
-      if (data->requires_broadcast) {
+      if (need_broadcast) {
+        TF_LITE_ADD(optimized_ops, BroadcastAddFivefold, float);
+      } else if (requires_flat_size_broadcast) {
         TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow, float);
       } else {
         TF_LITE_ADD(optimized_ops, Add, float);
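
On the optimized float path, the dispatch order above is: the fivefold kernel when ProcessBroadcastShapes recognizes a fivefold-compatible broadcast (need_broadcast), otherwise BroadcastAdd4DSlow when the shapes differ at all (requires_flat_size_broadcast, the same condition the removed data->requires_broadcast field used to carry), otherwise the flat elementwise Add. A toy sketch of that priority follows; AddKind and DispatchFloatAdd are made-up names standing in for the TF_LITE_ADD macro dispatch, and in the real kernel the two flags come from optimized_ops::ProcessBroadcastShapes and !HaveSameShapes(input1, input2).

    // Toy emulation of the new float dispatch order (illustration only).
    #include <cstdio>

    enum class AddKind { kFivefold, kBroadcast4DSlow, kElementwise };

    AddKind DispatchFloatAdd(bool need_broadcast, bool requires_flat_size_broadcast) {
      if (need_broadcast) return AddKind::kFivefold;                       // new fast path
      if (requires_flat_size_broadcast) return AddKind::kBroadcast4DSlow;  // generic fallback
      return AddKind::kElementwise;                                        // same shapes
    }

    int main() {
      std::printf("%d %d %d\n",
                  static_cast<int>(DispatchFloatAdd(true, true)),
                  static_cast<int>(DispatchFloatAdd(false, true)),
                  static_cast<int>(DispatchFloatAdd(false, false)));  // prints 0 1 2
      return 0;
    }
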
@@ -28,6 +28,7 @@ limitations under the License.
 #include <type_traits>

 #include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/internal/reference/add.h"

 #if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
 #include <Accelerate/Accelerate.h>
@@ -1485,14 +1486,11 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
   }
 }

-inline void Add(const ArithmeticParams& params,
-                const RuntimeShape& input1_shape, const float* input1_data,
-                const RuntimeShape& input2_shape, const float* input2_data,
-                const RuntimeShape& output_shape, float* output_data) {
-  gemmlowp::ScopedProfilingLabel label("Add");
+inline void AddElementwise(int size, const ArithmeticParams& params,
+                           const float* input1_data, const float* input2_data,
+                           float* output_data) {
   int i = 0;
-  const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape);

 #ifdef USE_NEON
   const auto activation_min = vdupq_n_f32(params.float_activation_min);
   const auto activation_max = vdupq_n_f32(params.float_activation_max);
@@ -1539,6 +1537,16 @@ inline void Add(const ArithmeticParams& params,
   }
 }

+inline void Add(const ArithmeticParams& params,
+                const RuntimeShape& input1_shape, const float* input1_data,
+                const RuntimeShape& input2_shape, const float* input2_data,
+                const RuntimeShape& output_shape, float* output_data) {
+  gemmlowp::ScopedProfilingLabel label("Add");
+  const int flat_size =
+      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+  AddElementwise(flat_size, params, input1_data, input2_data, output_data);
+}
+
 // Element-wise add that can often be used for inner loop of broadcast add as
 // well as the non-broadcast add.
 inline void AddElementwise(int size, const ArithmeticParams& params,
@@ -1907,6 +1915,76 @@ inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
   }
 }

+inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
+                                 const RuntimeShape& unswitched_input1_shape,
+                                 const float* unswitched_input1_data,
+                                 const RuntimeShape& unswitched_input2_shape,
+                                 const float* unswitched_input2_data,
+                                 const RuntimeShape& output_shape,
+                                 float* output_data) {
+  gemmlowp::ScopedProfilingLabel label("BroadcastAddFivefold/float");
+
+  ArithmeticParams switched_params = unswitched_params;
+
+  const bool use_unswitched =
+      unswitched_params.broadcast_category ==
+      tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
+
+  const ArithmeticParams& params =
+      use_unswitched ? unswitched_params : switched_params;
+  const float* input1_data =
+      use_unswitched ? unswitched_input1_data : unswitched_input2_data;
+  const float* input2_data =
+      use_unswitched ? unswitched_input2_data : unswitched_input1_data;
+
+  // Fivefold nested loops. The second input resets its position for each
+  // iteration of the second loop. The first input resets its position at the
+  // beginning of the fourth loop. The innermost loop is an elementwise add of
+  // sections of the arrays.
+  float* output_data_ptr = output_data;
+  const float* input1_data_ptr = input1_data;
+  const float* input2_data_reset = input2_data;
+  // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared
+  // between input shapes. y3 for input 1 is always broadcast, and so the
+  // dimension there is 1, whereas optionally y1 might be broadcast for input 2.
+  // Put another way,
+  // input1.shape.FlatSize = y0 * y1 * y2 * y4,
+  // input2.shape.FlatSize = y0 * y2 * y3 * y4.
+  int y0 = params.broadcast_shape[0];
+  int y1 = params.broadcast_shape[1];
+  int y2 = params.broadcast_shape[2];
+  int y3 = params.broadcast_shape[3];
+  int y4 = params.broadcast_shape[4];
+  if (y4 > 1) {
+    // General fivefold pattern, with y4 > 1 so there is a non-broadcast inner
+    // dimension.
+    for (int i0 = 0; i0 < y0; ++i0) {
+      const float* input2_data_ptr = nullptr;
+      for (int i1 = 0; i1 < y1; ++i1) {
+        input2_data_ptr = input2_data_reset;
+        for (int i2 = 0; i2 < y2; ++i2) {
+          for (int i3 = 0; i3 < y3; ++i3) {
+            AddElementwise(y4, params, input1_data_ptr, input2_data_ptr,
+                           output_data_ptr);
+            input2_data_ptr += y4;
+            output_data_ptr += y4;
+          }
+          // We have broadcast y4 of input1 data y3 times, and now move on.
+          input1_data_ptr += y4;
+        }
+      }
+      // We have broadcast y2*y3*y4 of input2 data y1 times, and now move on.
+      input2_data_reset = input2_data_ptr;
+    }
+  } else {
+    // TODO(renjieliu): Optimize for the scalar broadcast case.
+    reference_ops::BroadcastAdd4DSlow(
+        unswitched_params, unswitched_input1_shape, unswitched_input1_data,
+        unswitched_input2_shape, unswitched_input2_data, output_shape,
+        output_data);
+  }
+}
+
 inline void Mul(const ArithmeticParams& params,
                 const RuntimeShape& input1_shape, const float* input1_data,
                 const RuntimeShape& input2_shape, const float* input2_data,
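
The optimized_ops.h side of the change splits the old float Add into an AddElementwise(size, ...) inner routine plus a thin Add wrapper, so the new float BroadcastAddFivefold can run the same inner loop (NEON-vectorized in the real code) over each contiguous y4-sized slice. Below is a scalar stand-in for that split, assuming a pared-down Params struct in place of tflite::ArithmeticParams that keeps only the fused-activation clamp bounds.

    // Scalar sketch of the AddElementwise / Add split (illustration only).
    #include <algorithm>
    #include <cassert>
    #include <vector>

    struct Params {
      float float_activation_min;
      float float_activation_max;
    };

    // Inner loop shared by the flat (non-broadcast) add and by the fivefold
    // broadcast add, which calls it once per contiguous y4-sized slice.
    inline void AddElementwise(int size, const Params& params, const float* input1,
                               const float* input2, float* output) {
      for (int i = 0; i < size; ++i) {
        output[i] = std::min(std::max(input1[i] + input2[i],
                                      params.float_activation_min),
                             params.float_activation_max);
      }
    }

    // The flat Add becomes a thin wrapper: compute the flat size (here passed in
    // directly instead of via MatchingFlatSize) and defer to the shared loop.
    inline void Add(const Params& params, int flat_size, const float* input1,
                    const float* input2, float* output) {
      AddElementwise(flat_size, params, input1, input2, output);
    }

    int main() {
      const Params params{0.0f, 6.0f};  // e.g. a fused ReLU6 activation
      std::vector<float> a = {1.0f, -2.0f, 3.0f, 9.0f};
      std::vector<float> b = {0.5f, 0.5f, 0.5f, 0.5f};
      std::vector<float> out(a.size());
      Add(params, static_cast<int>(a.size()), a.data(), b.data(), out.data());
      assert(out[0] == 1.5f && out[1] == 0.0f && out[2] == 3.5f && out[3] == 6.0f);
      return 0;
    }
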