Migrate uint8/float broadcast mul to use binary broadcast fivefold
PiperOrigin-RevId: 314253577 Change-Id: I99e59937ecb693228cc2118fbcc4d8cd663fef03
This commit is contained in:
parent
dcfd9ba9ac
commit
918731364a
|
@ -2565,144 +2565,6 @@ inline void Mul(const ArithmeticParams& params,
|
|||
MulElementwise(flat_size, params, input1_data, input2_data, output_data);
|
||||
}
|
||||
|
||||
inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
|
||||
const RuntimeShape& unswitched_input1_shape,
|
||||
const uint8* unswitched_input1_data,
|
||||
const RuntimeShape& unswitched_input2_shape,
|
||||
const uint8* unswitched_input2_data,
|
||||
const RuntimeShape& output_shape,
|
||||
uint8* output_data) {
|
||||
ruy::profiler::ScopeLabel label("BroadcastMulFivefold/8bit");
|
||||
|
||||
ArithmeticParams switched_params = unswitched_params;
|
||||
switched_params.input1_offset = unswitched_params.input2_offset;
|
||||
switched_params.input2_offset = unswitched_params.input1_offset;
|
||||
|
||||
const bool use_unswitched =
|
||||
unswitched_params.broadcast_category ==
|
||||
tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
|
||||
|
||||
const ArithmeticParams& params =
|
||||
use_unswitched ? unswitched_params : switched_params;
|
||||
const uint8* input1_data =
|
||||
use_unswitched ? unswitched_input1_data : unswitched_input2_data;
|
||||
const uint8* input2_data =
|
||||
use_unswitched ? unswitched_input2_data : unswitched_input1_data;
|
||||
|
||||
// Fivefold nested loops. The second input resets its position for each
|
||||
// iteration of the second loop. The first input resets its position at the
|
||||
// beginning of the fourth loop. The innermost loop is an elementwise Mul of
|
||||
// sections of the arrays.
|
||||
uint8* output_data_ptr = output_data;
|
||||
const uint8* input1_data_ptr = input1_data;
|
||||
const uint8* input2_data_reset = input2_data;
|
||||
int y0 = params.broadcast_shape[0];
|
||||
int y1 = params.broadcast_shape[1];
|
||||
int y2 = params.broadcast_shape[2];
|
||||
int y3 = params.broadcast_shape[3];
|
||||
int y4 = params.broadcast_shape[4];
|
||||
if (y4 > 1) {
|
||||
for (int i0 = 0; i0 < y0; ++i0) {
|
||||
const uint8* input2_data_ptr = nullptr;
|
||||
for (int i1 = 0; i1 < y1; ++i1) {
|
||||
input2_data_ptr = input2_data_reset;
|
||||
for (int i2 = 0; i2 < y2; ++i2) {
|
||||
for (int i3 = 0; i3 < y3; ++i3) {
|
||||
MulElementwise(y4, params, input1_data_ptr, input2_data_ptr,
|
||||
output_data_ptr);
|
||||
input2_data_ptr += y4;
|
||||
output_data_ptr += y4;
|
||||
}
|
||||
input1_data_ptr += y4;
|
||||
}
|
||||
}
|
||||
input2_data_reset = input2_data_ptr;
|
||||
}
|
||||
} else {
|
||||
for (int i0 = 0; i0 < y0; ++i0) {
|
||||
const uint8* input2_data_ptr = nullptr;
|
||||
for (int i1 = 0; i1 < y1; ++i1) {
|
||||
input2_data_ptr = input2_data_reset;
|
||||
for (int i2 = 0; i2 < y2; ++i2) {
|
||||
MulSimpleBroadcast(y3, params, *input1_data_ptr, input2_data_ptr,
|
||||
output_data_ptr);
|
||||
input2_data_ptr += y3;
|
||||
output_data_ptr += y3;
|
||||
++input1_data_ptr;
|
||||
}
|
||||
}
|
||||
input2_data_reset = input2_data_ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void BroadcastMulFivefold(const ArithmeticParams& params,
|
||||
const RuntimeShape& unswitched_input1_shape,
|
||||
const float* unswitched_input1_data,
|
||||
const RuntimeShape& unswitched_input2_shape,
|
||||
const float* unswitched_input2_data,
|
||||
const RuntimeShape& output_shape,
|
||||
float* output_data) {
|
||||
ruy::profiler::ScopeLabel label("BroadcastMulFivefold/float");
|
||||
|
||||
const bool use_unswitched =
|
||||
params.broadcast_category ==
|
||||
tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
|
||||
|
||||
const float* input1_data =
|
||||
use_unswitched ? unswitched_input1_data : unswitched_input2_data;
|
||||
const float* input2_data =
|
||||
use_unswitched ? unswitched_input2_data : unswitched_input1_data;
|
||||
|
||||
// Fivefold nested loops. The second input resets its position for each
|
||||
// iteration of the second loop. The first input resets its position at the
|
||||
// beginning of the fourth loop. The innermost loop is an elementwise Mul of
|
||||
// sections of the arrays.
|
||||
float* output_data_ptr = output_data;
|
||||
const float* input1_data_ptr = input1_data;
|
||||
const float* input2_data_reset = input2_data;
|
||||
int y0 = params.broadcast_shape[0];
|
||||
int y1 = params.broadcast_shape[1];
|
||||
int y2 = params.broadcast_shape[2];
|
||||
int y3 = params.broadcast_shape[3];
|
||||
int y4 = params.broadcast_shape[4];
|
||||
if (y4 > 1) {
|
||||
for (int i0 = 0; i0 < y0; ++i0) {
|
||||
const float* input2_data_ptr = nullptr;
|
||||
for (int i1 = 0; i1 < y1; ++i1) {
|
||||
input2_data_ptr = input2_data_reset;
|
||||
for (int i2 = 0; i2 < y2; ++i2) {
|
||||
for (int i3 = 0; i3 < y3; ++i3) {
|
||||
MulElementwise(y4, params, input1_data_ptr, input2_data_ptr,
|
||||
output_data_ptr);
|
||||
input2_data_ptr += y4;
|
||||
output_data_ptr += y4;
|
||||
}
|
||||
input1_data_ptr += y4;
|
||||
}
|
||||
}
|
||||
input2_data_reset = input2_data_ptr;
|
||||
}
|
||||
} else {
|
||||
for (int i0 = 0; i0 < y0; ++i0) {
|
||||
const float* input2_data_ptr = nullptr;
|
||||
for (int i1 = 0; i1 < y1; ++i1) {
|
||||
input2_data_ptr = input2_data_reset;
|
||||
for (int i2 = 0; i2 < y2; ++i2) {
|
||||
// The input may be switched here, but the common parameters here
|
||||
// do not matter as they will not influence the float math execution.
|
||||
MulSimpleBroadcast(y3, params, *input1_data_ptr, input2_data_ptr,
|
||||
output_data_ptr);
|
||||
input2_data_ptr += y3;
|
||||
output_data_ptr += y3;
|
||||
++input1_data_ptr;
|
||||
}
|
||||
}
|
||||
input2_data_reset = input2_data_ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void BroadcastMulDispatch(
|
||||
const ArithmeticParams& params, const RuntimeShape& input1_shape,
|
||||
|
@ -2713,10 +2575,41 @@ inline void BroadcastMulDispatch(
|
|||
input2_data, output_shape, output_data);
|
||||
}
|
||||
|
||||
BroadcastMulFivefold(params, input1_shape, input1_data, input2_shape,
|
||||
input2_data, output_shape, output_data);
|
||||
BinaryBroadcastFiveFold(
|
||||
params, input1_shape, input1_data, input2_shape, input2_data,
|
||||
output_shape, output_data,
|
||||
static_cast<void (*)(int, const ArithmeticParams&, const T*, const T*,
|
||||
T*)>(MulElementwise),
|
||||
static_cast<void (*)(int, const ArithmeticParams&, T, const T*, T*)>(
|
||||
MulSimpleBroadcast));
|
||||
}
|
||||
|
||||
inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
|
||||
const RuntimeShape& unswitched_input1_shape,
|
||||
const uint8* unswitched_input1_data,
|
||||
const RuntimeShape& unswitched_input2_shape,
|
||||
const uint8* unswitched_input2_data,
|
||||
const RuntimeShape& output_shape,
|
||||
uint8* output_data) {
|
||||
BroadcastMulDispatch(unswitched_params, unswitched_input1_shape,
|
||||
unswitched_input1_data, unswitched_input2_shape,
|
||||
unswitched_input2_data, output_shape, output_data);
|
||||
}
|
||||
|
||||
inline void BroadcastMulFivefold(const ArithmeticParams& params,
|
||||
const RuntimeShape& unswitched_input1_shape,
|
||||
const float* unswitched_input1_data,
|
||||
const RuntimeShape& unswitched_input2_shape,
|
||||
const float* unswitched_input2_data,
|
||||
const RuntimeShape& output_shape,
|
||||
float* output_data) {
|
||||
BroadcastMulDispatch(params, unswitched_input1_shape, unswitched_input1_data,
|
||||
unswitched_input2_shape, unswitched_input2_data,
|
||||
output_shape, output_data);
|
||||
}
|
||||
|
||||
|
||||
|
||||
// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
|
||||
// dimensionality if the runtime code does a single loop over one dimension
|
||||
// that handles broadcasting as the base case. The code generator would then
|
||||
|
|
Loading…
Reference in New Issue