Migrate uint8/float broadcast mul to use binary broadcast fivefold

PiperOrigin-RevId: 314253577
Change-Id: I99e59937ecb693228cc2118fbcc4d8cd663fef03
This commit is contained in:
Renjie Liu 2020-06-01 19:50:39 -07:00 committed by TensorFlower Gardener
parent dcfd9ba9ac
commit 918731364a
1 changed files with 33 additions and 140 deletions

View File

@ -2565,144 +2565,6 @@ inline void Mul(const ArithmeticParams& params,
MulElementwise(flat_size, params, input1_data, input2_data, output_data);
}
inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
const RuntimeShape& unswitched_input1_shape,
const uint8* unswitched_input1_data,
const RuntimeShape& unswitched_input2_shape,
const uint8* unswitched_input2_data,
const RuntimeShape& output_shape,
uint8* output_data) {
ruy::profiler::ScopeLabel label("BroadcastMulFivefold/8bit");
ArithmeticParams switched_params = unswitched_params;
switched_params.input1_offset = unswitched_params.input2_offset;
switched_params.input2_offset = unswitched_params.input1_offset;
const bool use_unswitched =
unswitched_params.broadcast_category ==
tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
const ArithmeticParams& params =
use_unswitched ? unswitched_params : switched_params;
const uint8* input1_data =
use_unswitched ? unswitched_input1_data : unswitched_input2_data;
const uint8* input2_data =
use_unswitched ? unswitched_input2_data : unswitched_input1_data;
// Fivefold nested loops. The second input resets its position for each
// iteration of the second loop. The first input resets its position at the
// beginning of the fourth loop. The innermost loop is an elementwise Mul of
// sections of the arrays.
uint8* output_data_ptr = output_data;
const uint8* input1_data_ptr = input1_data;
const uint8* input2_data_reset = input2_data;
int y0 = params.broadcast_shape[0];
int y1 = params.broadcast_shape[1];
int y2 = params.broadcast_shape[2];
int y3 = params.broadcast_shape[3];
int y4 = params.broadcast_shape[4];
if (y4 > 1) {
for (int i0 = 0; i0 < y0; ++i0) {
const uint8* input2_data_ptr = nullptr;
for (int i1 = 0; i1 < y1; ++i1) {
input2_data_ptr = input2_data_reset;
for (int i2 = 0; i2 < y2; ++i2) {
for (int i3 = 0; i3 < y3; ++i3) {
MulElementwise(y4, params, input1_data_ptr, input2_data_ptr,
output_data_ptr);
input2_data_ptr += y4;
output_data_ptr += y4;
}
input1_data_ptr += y4;
}
}
input2_data_reset = input2_data_ptr;
}
} else {
for (int i0 = 0; i0 < y0; ++i0) {
const uint8* input2_data_ptr = nullptr;
for (int i1 = 0; i1 < y1; ++i1) {
input2_data_ptr = input2_data_reset;
for (int i2 = 0; i2 < y2; ++i2) {
MulSimpleBroadcast(y3, params, *input1_data_ptr, input2_data_ptr,
output_data_ptr);
input2_data_ptr += y3;
output_data_ptr += y3;
++input1_data_ptr;
}
}
input2_data_reset = input2_data_ptr;
}
}
}
inline void BroadcastMulFivefold(const ArithmeticParams& params,
const RuntimeShape& unswitched_input1_shape,
const float* unswitched_input1_data,
const RuntimeShape& unswitched_input2_shape,
const float* unswitched_input2_data,
const RuntimeShape& output_shape,
float* output_data) {
ruy::profiler::ScopeLabel label("BroadcastMulFivefold/float");
const bool use_unswitched =
params.broadcast_category ==
tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
const float* input1_data =
use_unswitched ? unswitched_input1_data : unswitched_input2_data;
const float* input2_data =
use_unswitched ? unswitched_input2_data : unswitched_input1_data;
// Fivefold nested loops. The second input resets its position for each
// iteration of the second loop. The first input resets its position at the
// beginning of the fourth loop. The innermost loop is an elementwise Mul of
// sections of the arrays.
float* output_data_ptr = output_data;
const float* input1_data_ptr = input1_data;
const float* input2_data_reset = input2_data;
int y0 = params.broadcast_shape[0];
int y1 = params.broadcast_shape[1];
int y2 = params.broadcast_shape[2];
int y3 = params.broadcast_shape[3];
int y4 = params.broadcast_shape[4];
if (y4 > 1) {
for (int i0 = 0; i0 < y0; ++i0) {
const float* input2_data_ptr = nullptr;
for (int i1 = 0; i1 < y1; ++i1) {
input2_data_ptr = input2_data_reset;
for (int i2 = 0; i2 < y2; ++i2) {
for (int i3 = 0; i3 < y3; ++i3) {
MulElementwise(y4, params, input1_data_ptr, input2_data_ptr,
output_data_ptr);
input2_data_ptr += y4;
output_data_ptr += y4;
}
input1_data_ptr += y4;
}
}
input2_data_reset = input2_data_ptr;
}
} else {
for (int i0 = 0; i0 < y0; ++i0) {
const float* input2_data_ptr = nullptr;
for (int i1 = 0; i1 < y1; ++i1) {
input2_data_ptr = input2_data_reset;
for (int i2 = 0; i2 < y2; ++i2) {
// The input may be switched here, but the common parameters here
// do not matter as they will not influence the float math execution.
MulSimpleBroadcast(y3, params, *input1_data_ptr, input2_data_ptr,
output_data_ptr);
input2_data_ptr += y3;
output_data_ptr += y3;
++input1_data_ptr;
}
}
input2_data_reset = input2_data_ptr;
}
}
}
template <typename T>
inline void BroadcastMulDispatch(
const ArithmeticParams& params, const RuntimeShape& input1_shape,
@ -2713,10 +2575,41 @@ inline void BroadcastMulDispatch(
input2_data, output_shape, output_data);
}
BroadcastMulFivefold(params, input1_shape, input1_data, input2_shape,
input2_data, output_shape, output_data);
BinaryBroadcastFiveFold(
params, input1_shape, input1_data, input2_shape, input2_data,
output_shape, output_data,
static_cast<void (*)(int, const ArithmeticParams&, const T*, const T*,
T*)>(MulElementwise),
static_cast<void (*)(int, const ArithmeticParams&, T, const T*, T*)>(
MulSimpleBroadcast));
}
inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
const RuntimeShape& unswitched_input1_shape,
const uint8* unswitched_input1_data,
const RuntimeShape& unswitched_input2_shape,
const uint8* unswitched_input2_data,
const RuntimeShape& output_shape,
uint8* output_data) {
BroadcastMulDispatch(unswitched_params, unswitched_input1_shape,
unswitched_input1_data, unswitched_input2_shape,
unswitched_input2_data, output_shape, output_data);
}
inline void BroadcastMulFivefold(const ArithmeticParams& params,
const RuntimeShape& unswitched_input1_shape,
const float* unswitched_input1_data,
const RuntimeShape& unswitched_input2_shape,
const float* unswitched_input2_data,
const RuntimeShape& output_shape,
float* output_data) {
BroadcastMulDispatch(params, unswitched_input1_shape, unswitched_input1_data,
unswitched_input2_shape, unswitched_input2_data,
output_shape, output_data);
}
// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then