Non-broadcast Div optimized

This commit is contained in:
Michal W. Tarnowski 2019-04-13 22:00:06 +02:00
parent 55cc896ca4
commit 0840008136

View File

@ -3214,6 +3214,84 @@ inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
} }
} }
inline void Div(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const float* input1_data,
const RuntimeShape& input2_shape, const float* input2_data,
const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Div");
const float output_activation_min = params.float_activation_min;
const float output_activation_max = params.float_activation_max;
int i = 0;
const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
#ifdef USE_NEON
static constexpr int kNewtonSteps = 2;
static const float32x4_t TWO_F32 = vdupq_n_f32(2.f);
const float32x4_t activation_min = vdupq_n_f32(output_activation_min);
const float32x4_t activation_max = vdupq_n_f32(output_activation_max);
for (; i <= size - 16; i += 16) {
const float32x4_t a10 = vld1q_f32(input1_data + i);
const float32x4_t a11 = vld1q_f32(input1_data + i + 4);
const float32x4_t a12 = vld1q_f32(input1_data + i + 8);
const float32x4_t a13 = vld1q_f32(input1_data + i + 12);
const float32x4_t a20 = vld1q_f32(input2_data + i);
const float32x4_t a21 = vld1q_f32(input2_data + i + 4);
const float32x4_t a22 = vld1q_f32(input2_data + i + 8);
const float32x4_t a23 = vld1q_f32(input2_data + i + 12);
float32x4_t r0 = vrecpeq_f32(a20);
float32x4_t r1 = vrecpeq_f32(a21);
float32x4_t r2 = vrecpeq_f32(a22);
float32x4_t r3 = vrecpeq_f32(a23);
for (int k = 0; k < kNewtonSteps; ++k) {
r0 = vmulq_f32(r0, vsubq_f32(TWO_F32, vmulq_f32(r0, a20)));
r1 = vmulq_f32(r1, vsubq_f32(TWO_F32, vmulq_f32(r1, a21)));
r2 = vmulq_f32(r2, vsubq_f32(TWO_F32, vmulq_f32(r2, a22)));
r3 = vmulq_f32(r3, vsubq_f32(TWO_F32, vmulq_f32(r3, a23)));
}
float32x4_t x0 = vmulq_f32(a10, r0);
float32x4_t x1 = vmulq_f32(a11, r1);
float32x4_t x2 = vmulq_f32(a12, r2);
float32x4_t x3 = vmulq_f32(a13, r3);
x0 = vmaxq_f32(activation_min, x0);
x1 = vmaxq_f32(activation_min, x1);
x2 = vmaxq_f32(activation_min, x2);
x3 = vmaxq_f32(activation_min, x3);
x0 = vminq_f32(activation_max, x0);
x1 = vminq_f32(activation_max, x1);
x2 = vminq_f32(activation_max, x2);
x3 = vminq_f32(activation_max, x3);
vst1q_f32(output_data + i, x0);
vst1q_f32(output_data + i + 4, x1);
vst1q_f32(output_data + i + 8, x2);
vst1q_f32(output_data + i + 12, x3);
}
for (; i <= size - 4; i += 4) {
const float32x4_t a1 = vld1q_f32(input1_data + i);
const float32x4_t a2 = vld1q_f32(input2_data + i);
float32x4_t r = vrecpeq_f32(a2);
for (int k = 0; k < kNewtonSteps; ++k) {
r = vmulq_f32(r, vsubq_f32(TWO_F32, vmulq_f32(r, a2)));
}
float32x4_t x = vmulq_f32(a1, r);
x = vmaxq_f32(activation_min, x);
x = vminq_f32(activation_max, x);
vst1q_f32(output_data + i, x);
}
#endif // NEON
for (; i < size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] / input2_data[i], output_activation_min,
output_activation_max);
}
}
// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary // TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension // dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then // that handles broadcasting as the base case. The code generator would then