Migrate uint8/float broadcast mul to use binary broadcast fivefold
PiperOrigin-RevId: 314253577 Change-Id: I99e59937ecb693228cc2118fbcc4d8cd663fef03
This commit is contained in:
parent
dcfd9ba9ac
commit
918731364a
@ -2565,144 +2565,6 @@ inline void Mul(const ArithmeticParams& params,
|
|||||||
MulElementwise(flat_size, params, input1_data, input2_data, output_data);
|
MulElementwise(flat_size, params, input1_data, input2_data, output_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
|
|
||||||
const RuntimeShape& unswitched_input1_shape,
|
|
||||||
const uint8* unswitched_input1_data,
|
|
||||||
const RuntimeShape& unswitched_input2_shape,
|
|
||||||
const uint8* unswitched_input2_data,
|
|
||||||
const RuntimeShape& output_shape,
|
|
||||||
uint8* output_data) {
|
|
||||||
ruy::profiler::ScopeLabel label("BroadcastMulFivefold/8bit");
|
|
||||||
|
|
||||||
ArithmeticParams switched_params = unswitched_params;
|
|
||||||
switched_params.input1_offset = unswitched_params.input2_offset;
|
|
||||||
switched_params.input2_offset = unswitched_params.input1_offset;
|
|
||||||
|
|
||||||
const bool use_unswitched =
|
|
||||||
unswitched_params.broadcast_category ==
|
|
||||||
tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
|
|
||||||
|
|
||||||
const ArithmeticParams& params =
|
|
||||||
use_unswitched ? unswitched_params : switched_params;
|
|
||||||
const uint8* input1_data =
|
|
||||||
use_unswitched ? unswitched_input1_data : unswitched_input2_data;
|
|
||||||
const uint8* input2_data =
|
|
||||||
use_unswitched ? unswitched_input2_data : unswitched_input1_data;
|
|
||||||
|
|
||||||
// Fivefold nested loops. The second input resets its position for each
|
|
||||||
// iteration of the second loop. The first input resets its position at the
|
|
||||||
// beginning of the fourth loop. The innermost loop is an elementwise Mul of
|
|
||||||
// sections of the arrays.
|
|
||||||
uint8* output_data_ptr = output_data;
|
|
||||||
const uint8* input1_data_ptr = input1_data;
|
|
||||||
const uint8* input2_data_reset = input2_data;
|
|
||||||
int y0 = params.broadcast_shape[0];
|
|
||||||
int y1 = params.broadcast_shape[1];
|
|
||||||
int y2 = params.broadcast_shape[2];
|
|
||||||
int y3 = params.broadcast_shape[3];
|
|
||||||
int y4 = params.broadcast_shape[4];
|
|
||||||
if (y4 > 1) {
|
|
||||||
for (int i0 = 0; i0 < y0; ++i0) {
|
|
||||||
const uint8* input2_data_ptr = nullptr;
|
|
||||||
for (int i1 = 0; i1 < y1; ++i1) {
|
|
||||||
input2_data_ptr = input2_data_reset;
|
|
||||||
for (int i2 = 0; i2 < y2; ++i2) {
|
|
||||||
for (int i3 = 0; i3 < y3; ++i3) {
|
|
||||||
MulElementwise(y4, params, input1_data_ptr, input2_data_ptr,
|
|
||||||
output_data_ptr);
|
|
||||||
input2_data_ptr += y4;
|
|
||||||
output_data_ptr += y4;
|
|
||||||
}
|
|
||||||
input1_data_ptr += y4;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
input2_data_reset = input2_data_ptr;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (int i0 = 0; i0 < y0; ++i0) {
|
|
||||||
const uint8* input2_data_ptr = nullptr;
|
|
||||||
for (int i1 = 0; i1 < y1; ++i1) {
|
|
||||||
input2_data_ptr = input2_data_reset;
|
|
||||||
for (int i2 = 0; i2 < y2; ++i2) {
|
|
||||||
MulSimpleBroadcast(y3, params, *input1_data_ptr, input2_data_ptr,
|
|
||||||
output_data_ptr);
|
|
||||||
input2_data_ptr += y3;
|
|
||||||
output_data_ptr += y3;
|
|
||||||
++input1_data_ptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
input2_data_reset = input2_data_ptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inline void BroadcastMulFivefold(const ArithmeticParams& params,
|
|
||||||
const RuntimeShape& unswitched_input1_shape,
|
|
||||||
const float* unswitched_input1_data,
|
|
||||||
const RuntimeShape& unswitched_input2_shape,
|
|
||||||
const float* unswitched_input2_data,
|
|
||||||
const RuntimeShape& output_shape,
|
|
||||||
float* output_data) {
|
|
||||||
ruy::profiler::ScopeLabel label("BroadcastMulFivefold/float");
|
|
||||||
|
|
||||||
const bool use_unswitched =
|
|
||||||
params.broadcast_category ==
|
|
||||||
tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
|
|
||||||
|
|
||||||
const float* input1_data =
|
|
||||||
use_unswitched ? unswitched_input1_data : unswitched_input2_data;
|
|
||||||
const float* input2_data =
|
|
||||||
use_unswitched ? unswitched_input2_data : unswitched_input1_data;
|
|
||||||
|
|
||||||
// Fivefold nested loops. The second input resets its position for each
|
|
||||||
// iteration of the second loop. The first input resets its position at the
|
|
||||||
// beginning of the fourth loop. The innermost loop is an elementwise Mul of
|
|
||||||
// sections of the arrays.
|
|
||||||
float* output_data_ptr = output_data;
|
|
||||||
const float* input1_data_ptr = input1_data;
|
|
||||||
const float* input2_data_reset = input2_data;
|
|
||||||
int y0 = params.broadcast_shape[0];
|
|
||||||
int y1 = params.broadcast_shape[1];
|
|
||||||
int y2 = params.broadcast_shape[2];
|
|
||||||
int y3 = params.broadcast_shape[3];
|
|
||||||
int y4 = params.broadcast_shape[4];
|
|
||||||
if (y4 > 1) {
|
|
||||||
for (int i0 = 0; i0 < y0; ++i0) {
|
|
||||||
const float* input2_data_ptr = nullptr;
|
|
||||||
for (int i1 = 0; i1 < y1; ++i1) {
|
|
||||||
input2_data_ptr = input2_data_reset;
|
|
||||||
for (int i2 = 0; i2 < y2; ++i2) {
|
|
||||||
for (int i3 = 0; i3 < y3; ++i3) {
|
|
||||||
MulElementwise(y4, params, input1_data_ptr, input2_data_ptr,
|
|
||||||
output_data_ptr);
|
|
||||||
input2_data_ptr += y4;
|
|
||||||
output_data_ptr += y4;
|
|
||||||
}
|
|
||||||
input1_data_ptr += y4;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
input2_data_reset = input2_data_ptr;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (int i0 = 0; i0 < y0; ++i0) {
|
|
||||||
const float* input2_data_ptr = nullptr;
|
|
||||||
for (int i1 = 0; i1 < y1; ++i1) {
|
|
||||||
input2_data_ptr = input2_data_reset;
|
|
||||||
for (int i2 = 0; i2 < y2; ++i2) {
|
|
||||||
// The input may be switched here, but the common parameters here
|
|
||||||
// do not matter as they will not influence the float math execution.
|
|
||||||
MulSimpleBroadcast(y3, params, *input1_data_ptr, input2_data_ptr,
|
|
||||||
output_data_ptr);
|
|
||||||
input2_data_ptr += y3;
|
|
||||||
output_data_ptr += y3;
|
|
||||||
++input1_data_ptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
input2_data_reset = input2_data_ptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline void BroadcastMulDispatch(
|
inline void BroadcastMulDispatch(
|
||||||
const ArithmeticParams& params, const RuntimeShape& input1_shape,
|
const ArithmeticParams& params, const RuntimeShape& input1_shape,
|
||||||
@ -2713,10 +2575,41 @@ inline void BroadcastMulDispatch(
|
|||||||
input2_data, output_shape, output_data);
|
input2_data, output_shape, output_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
BroadcastMulFivefold(params, input1_shape, input1_data, input2_shape,
|
BinaryBroadcastFiveFold(
|
||||||
input2_data, output_shape, output_data);
|
params, input1_shape, input1_data, input2_shape, input2_data,
|
||||||
|
output_shape, output_data,
|
||||||
|
static_cast<void (*)(int, const ArithmeticParams&, const T*, const T*,
|
||||||
|
T*)>(MulElementwise),
|
||||||
|
static_cast<void (*)(int, const ArithmeticParams&, T, const T*, T*)>(
|
||||||
|
MulSimpleBroadcast));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
|
||||||
|
const RuntimeShape& unswitched_input1_shape,
|
||||||
|
const uint8* unswitched_input1_data,
|
||||||
|
const RuntimeShape& unswitched_input2_shape,
|
||||||
|
const uint8* unswitched_input2_data,
|
||||||
|
const RuntimeShape& output_shape,
|
||||||
|
uint8* output_data) {
|
||||||
|
BroadcastMulDispatch(unswitched_params, unswitched_input1_shape,
|
||||||
|
unswitched_input1_data, unswitched_input2_shape,
|
||||||
|
unswitched_input2_data, output_shape, output_data);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void BroadcastMulFivefold(const ArithmeticParams& params,
|
||||||
|
const RuntimeShape& unswitched_input1_shape,
|
||||||
|
const float* unswitched_input1_data,
|
||||||
|
const RuntimeShape& unswitched_input2_shape,
|
||||||
|
const float* unswitched_input2_data,
|
||||||
|
const RuntimeShape& output_shape,
|
||||||
|
float* output_data) {
|
||||||
|
BroadcastMulDispatch(params, unswitched_input1_shape, unswitched_input1_data,
|
||||||
|
unswitched_input2_shape, unswitched_input2_data,
|
||||||
|
output_shape, output_data);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
|
// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
|
||||||
// dimensionality if the runtime code does a single loop over one dimension
|
// dimensionality if the runtime code does a single loop over one dimension
|
||||||
// that handles broadcasting as the base case. The code generator would then
|
// that handles broadcasting as the base case. The code generator would then
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user