Optimize the float PRelu kernel.

PiperOrigin-RevId: 332182719
Change-Id: Idee8c15d156a5dab62771d126fbec1eb5546fba9
This commit is contained in:
Renjie Liu 2020-09-17 01:19:39 -07:00 committed by TensorFlower Gardener
parent 151800d58a
commit 853223a456
3 changed files with 188 additions and 20 deletions

View File

@ -1181,6 +1181,7 @@ T ApplyPrelu(T input, T alpha) {
return input >= 0.0 ? input : input * alpha;
}
template <KernelType kernel_type>
TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
const TfLiteTensor* alpha = GetInput(context, node, 1);
@ -1188,18 +1189,38 @@ TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
const PreluOpData* data = reinterpret_cast<PreluOpData*>(node->user_data);
switch (input->type) {
case kTfLiteFloat32: {
if (data->requires_broadcast) {
reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(alpha), GetTensorData<float>(alpha),
GetTensorShape(output), GetTensorData<float>(output),
ApplyPrelu<float>);
if (kernel_type == kGenericOptimized) {
tflite::ArithmeticParams op_params;
bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
GetTensorShape(input), GetTensorShape(alpha), &op_params);
if (need_broadcast) {
optimized_ops::BroadcastPReluDispatch(
op_params, GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(alpha), GetTensorData<float>(alpha),
GetTensorShape(output), GetTensorData<float>(output),
ApplyPrelu<float>);
} else {
const int flat_size =
MatchingElementsSize(GetTensorShape(input), GetTensorShape(alpha),
GetTensorShape(output));
optimized_ops::PReluElementWise(
flat_size, op_params, GetTensorData<float>(alpha),
GetTensorData<float>(input), GetTensorData<float>(output));
}
} else {
reference_ops::BinaryFunction<float, float, float>(
GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(alpha), GetTensorData<float>(alpha),
GetTensorShape(output), GetTensorData<float>(output),
ApplyPrelu<float>);
if (data->requires_broadcast) {
reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(alpha), GetTensorData<float>(alpha),
GetTensorShape(output), GetTensorData<float>(output),
ApplyPrelu<float>);
} else {
reference_ops::BinaryFunction<float, float, float>(
GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(alpha), GetTensorData<float>(alpha),
GetTensorShape(output), GetTensorData<float>(output),
ApplyPrelu<float>);
}
}
return kTfLiteOk;
} break;
@ -1460,10 +1481,17 @@ TfLiteRegistration* Register_LOG_SOFTMAX() {
return &r;
}
// Returns the registration for the portable reference PRelu kernel
// (PreluEval instantiated with kReference).
TfLiteRegistration* Register_PRELU_REF() {
  static TfLiteRegistration r = {activations::PreluInit,
                                 activations::PreluFree,
                                 activations::PreluPrepare,
                                 activations::PreluEval<activations::kReference>};
  return &r;
}
// Returns the registration for the default PRelu kernel, which uses the
// optimized (NEON-accelerated where available) evaluation path.
//
// NOTE(review): the merged diff view left both the old and new initializers of
// `r` in place, which redefines the local `static` and does not compile; only
// the new, templated registration is kept here.
TfLiteRegistration* Register_PRELU() {
  static TfLiteRegistration r = {
      activations::PreluInit, activations::PreluFree, activations::PreluPrepare,
      activations::PreluEval<activations::kGenericOptimized>};
  return &r;
}

View File

@ -50,6 +50,10 @@ TfLiteRegistration* Register_LOGISTIC_REF();
TfLiteRegistration* Register_LOGISTIC_GENERIC_OPT();
TfLiteRegistration* Register_LOGISTIC_FIXED_POINT_OPT();
// PRelu kernel registrations.
TfLiteRegistration* Register_PRELU_REF();
TfLiteRegistration* Register_PRELU();
} // namespace builtin
} // namespace ops
@ -2031,6 +2035,11 @@ TEST(QuantizedActivationsOpTest, LogSoftmaxInt8) {
}));
}
// Maps kernel-variant tags to their PRelu registrations, used to run each
// parameterized test against both implementations. Heap-allocated with static
// lifetime and never freed — presumably intentional for test fixtures; the OS
// reclaims it at process exit.
const auto kPReluKernelMap = new std::map<string, TfLiteRegistration*>({
    {"Reference", ops::builtin::Register_PRELU_REF()},
    {"GenericOptimized", ops::builtin::Register_PRELU()},
});
// A base class of PRelu op model. It provides the constructor for
// FloatPReluOpModel and QuantizedPReluOpModel.
class BasePReluOpModel : public SingleOpModel {
@ -2087,7 +2096,14 @@ class QuantizedPReluOpModel : public BasePReluOpModel {
}
};
TEST(FloatActivationsOpTest, PRelu) {
// Parameterized test fixture for PRelu: each TEST_P runs once per kernel
// variant listed in kPReluKernelMap ("Reference" and "GenericOptimized").
class PReluOpTest : public SingleOpTest {
 protected:
  // Supplies the tag -> registration map the parameterized harness iterates.
  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
    return *kPReluKernelMap;
  }
};
TEST_P(PReluOpTest, PReluFloat32) {
FloatPReluOpModel m({TensorType_FLOAT32, {1, 2, 2, 3}},
{TensorType_FLOAT32, {1, 1, 3}});
@ -2107,7 +2123,7 @@ TEST(FloatActivationsOpTest, PRelu) {
}));
}
TEST(FloatActivationsOpTest, PReluSameShapes) {
TEST_P(PReluOpTest, PReluFloat32SameShapes) {
FloatPReluOpModel m({TensorType_FLOAT32, {1, 2, 2, 3}},
{TensorType_FLOAT32, {1, 2, 2, 3}});
@ -2132,7 +2148,7 @@ TEST(FloatActivationsOpTest, PReluSameShapes) {
}));
}
TEST(QuantizedActivationsOpTest, PRelu) {
TEST_P(PReluOpTest, PReluUInt8) {
const float kMin = -1;
const float kMax = 127.f / 128.f;
QuantizedPReluOpModel m({TensorType_UINT8, {1, 2, 2, 3}, kMin, kMax},
@ -2162,7 +2178,7 @@ TEST(QuantizedActivationsOpTest, PRelu) {
}));
}
TEST(QuantizedActivationsOpTest, PReluSameShapes) {
TEST_P(PReluOpTest, PReluUInt8SameShapes) {
const float kMin = -1;
const float kMax = 127.f / 128.f;
QuantizedPReluOpModel m({TensorType_UINT8, {1, 2, 2, 3}, kMin, kMax},
@ -2197,7 +2213,7 @@ TEST(QuantizedActivationsOpTest, PReluSameShapes) {
}));
}
TEST(QuantizedActivationsOpTest, PReluInt8) {
TEST_P(PReluOpTest, PReluInt8) {
const float kMin = -1;
const float kMax = 127.f / 128.f;
QuantizedPReluOpModel m({TensorType_INT8, {1, 2, 2, 3}, kMin, kMax},
@ -2227,7 +2243,7 @@ TEST(QuantizedActivationsOpTest, PReluInt8) {
}));
}
TEST(QuantizedActivationsOpTest, PReluInt8SameShapes) {
TEST_P(PReluOpTest, PReluInt8SameShapes) {
const float kMin = -1;
const float kMax = 127.f / 128.f;
QuantizedPReluOpModel m({TensorType_INT8, {1, 2, 2, 3}, kMin, kMax},
@ -2303,5 +2319,9 @@ INSTANTIATE_TEST_SUITE_P(
LogisticOpTest, LogisticOpTest,
::testing::ValuesIn(SingleOpTest::GetKernelTags(*kLogisticKernelMap)));
// Instantiates the PRelu test suite once per kernel tag in kPReluKernelMap,
// so every TEST_P exercises both the reference and optimized kernels.
INSTANTIATE_TEST_SUITE_P(
    PReluOpTest, PReluOpTest,
    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kPReluKernelMap)));
} // namespace
} // namespace tflite

View File

@ -8172,6 +8172,126 @@ void CumSum(const T* input_data, const RuntimeShape& shape, int axis,
CumsumImpl<T>(input_data, shape, axis, exclusive, reverse, output_data);
}
// Computes float PRelu with a single scalar `alpha` broadcast over all `size`
// input elements: output[i] = input[i] >= 0 ? input[i] : input[i] * alpha.
// `params` is unused in the body; the signature appears to match the callback
// shape BinaryBroadcastFiveFold expects (see BroadcastPReluDispatch below).
inline void PReluScalarBroadcast(int size, const ArithmeticParams& params,
                                 float alpha, const float* input_data,
                                 float* output_data) {
  ruy::profiler::ScopeLabel label("PreluScalarBroadcast/float");
  int i = 0;
#ifdef USE_NEON
  const float32x4_t zero_dup = vdupq_n_f32(0.0f);
  const float32x4_t alpha_dup = vdupq_n_f32(alpha);
  // Main vector loop: 16 floats (four 128-bit lanes) per iteration.
  for (; i <= size - 16; i += 16) {
    const float32x4_t input1 = vld1q_f32(input_data + i);
    const float32x4_t input2 = vld1q_f32(input_data + i + 4);
    const float32x4_t input3 = vld1q_f32(input_data + i + 8);
    const float32x4_t input4 = vld1q_f32(input_data + i + 12);
    // Compute input * alpha unconditionally; the select below picks between
    // the raw input (input >= 0) and the scaled value (input < 0).
    const float32x4_t temp1 = vmulq_f32(input1, alpha_dup);
    const float32x4_t temp2 = vmulq_f32(input2, alpha_dup);
    const float32x4_t temp3 = vmulq_f32(input3, alpha_dup);
    const float32x4_t temp4 = vmulq_f32(input4, alpha_dup);
    const uint32x4_t mask1 = vcgeq_f32(input1, zero_dup);
    const uint32x4_t mask2 = vcgeq_f32(input2, zero_dup);
    const uint32x4_t mask3 = vcgeq_f32(input3, zero_dup);
    const uint32x4_t mask4 = vcgeq_f32(input4, zero_dup);
    const float32x4_t result1 = vbslq_f32(mask1, input1, temp1);
    vst1q_f32(output_data + i, result1);
    const float32x4_t result2 = vbslq_f32(mask2, input2, temp2);
    vst1q_f32(output_data + i + 4, result2);
    const float32x4_t result3 = vbslq_f32(mask3, input3, temp3);
    vst1q_f32(output_data + i + 8, result3);
    const float32x4_t result4 = vbslq_f32(mask4, input4, temp4);
    vst1q_f32(output_data + i + 12, result4);
  }
  // Tail vector loop: 4 floats at a time for the remainder.
  for (; i <= size - 4; i += 4) {
    const float32x4_t input = vld1q_f32(input_data + i);
    const float32x4_t temp = vmulq_f32(input, alpha_dup);
    const uint32x4_t mask = vcgeq_f32(input, zero_dup);
    const float32x4_t result = vbslq_f32(mask, input, temp);
    vst1q_f32(output_data + i, result);
  }
#endif  // USE_NEON
  // Scalar tail (and the entire loop when NEON is unavailable).
  for (; i < size; ++i) {
    const float input = input_data[i];
    output_data[i] = input >= 0.f ? input : input * alpha;
  }
}
// Computes float PRelu element-wise over `flat_size` elements, with a
// per-element alpha: output[i] = input[i] >= 0 ? input[i] : input[i] *
// alpha[i]. Caller is responsible for input/alpha/output having matching flat
// sizes. `params` is unused in the body; the signature appears to match the
// callback shape BinaryBroadcastFiveFold expects.
inline void PReluElementWise(int flat_size, const ArithmeticParams& params,
                             const float* alpha_data, const float* input_data,
                             float* output_data) {
  ruy::profiler::ScopeLabel label("PreluElementWise/float");
  int i = 0;
#ifdef USE_NEON
  const float32x4_t zero_dup = vdupq_n_f32(0.0f);
  // Main vector loop: 16 floats (four 128-bit lanes) per iteration.
  for (; i <= flat_size - 16; i += 16) {
    const float32x4_t input1 = vld1q_f32(input_data + i);
    const float32x4_t alpha1 = vld1q_f32(alpha_data + i);
    const float32x4_t input2 = vld1q_f32(input_data + i + 4);
    const float32x4_t alpha2 = vld1q_f32(alpha_data + i + 4);
    const float32x4_t input3 = vld1q_f32(input_data + i + 8);
    const float32x4_t alpha3 = vld1q_f32(alpha_data + i + 8);
    const float32x4_t input4 = vld1q_f32(input_data + i + 12);
    const float32x4_t alpha4 = vld1q_f32(alpha_data + i + 12);
    // Compute input * alpha unconditionally; the select below picks between
    // the raw input (input >= 0) and the scaled value (input < 0).
    const float32x4_t temp1 = vmulq_f32(input1, alpha1);
    const float32x4_t temp2 = vmulq_f32(input2, alpha2);
    const float32x4_t temp3 = vmulq_f32(input3, alpha3);
    const float32x4_t temp4 = vmulq_f32(input4, alpha4);
    const uint32x4_t mask1 = vcgeq_f32(input1, zero_dup);
    const uint32x4_t mask2 = vcgeq_f32(input2, zero_dup);
    const uint32x4_t mask3 = vcgeq_f32(input3, zero_dup);
    const uint32x4_t mask4 = vcgeq_f32(input4, zero_dup);
    const float32x4_t result1 = vbslq_f32(mask1, input1, temp1);
    vst1q_f32(output_data + i, result1);
    const float32x4_t result2 = vbslq_f32(mask2, input2, temp2);
    vst1q_f32(output_data + i + 4, result2);
    const float32x4_t result3 = vbslq_f32(mask3, input3, temp3);
    vst1q_f32(output_data + i + 8, result3);
    const float32x4_t result4 = vbslq_f32(mask4, input4, temp4);
    vst1q_f32(output_data + i + 12, result4);
  }
  // Tail vector loop: 4 floats at a time for the remainder.
  for (; i <= flat_size - 4; i += 4) {
    const float32x4_t input = vld1q_f32(input_data + i);
    const float32x4_t alpha = vld1q_f32(alpha_data + i);
    const float32x4_t temp = vmulq_f32(input, alpha);
    const uint32x4_t mask = vcgeq_f32(input, zero_dup);
    const float32x4_t result = vbslq_f32(mask, input, temp);
    vst1q_f32(output_data + i, result);
  }
#endif  // USE_NEON
  // Scalar tail (and the entire loop when NEON is unavailable).
  for (; i < flat_size; ++i) {
    const float input = input_data[i];
    const float alpha = alpha_data[i];
    output_data[i] = input >= 0.f ? input : input * alpha;
  }
}
// Dispatches a broadcasting float PRelu to the fastest applicable path.
// Shapes that fit the five-fold broadcast pattern (as classified into
// `params.broadcast_category` by ProcessBroadcastShapes) go through the
// optimized BinaryBroadcastFiveFold with the vectorized element-wise and
// scalar-broadcast kernels; anything else falls back to the slow reference
// 4-D broadcast using the supplied scalar functor `func`.
inline void BroadcastPReluDispatch(
    const ArithmeticParams& params, const RuntimeShape& input_shape,
    const float* input_data, const RuntimeShape& alpha_shape,
    const float* alpha_data, const RuntimeShape& output_shape,
    float* output_data, float (*func)(float, float)) {
  const bool is_generic_broadcast =
      params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast;
  if (!is_generic_broadcast) {
    // Five-fold-compatible broadcast: take the SIMD-friendly fast path.
    BinaryBroadcastFiveFold(params, input_shape, input_data, alpha_shape,
                            alpha_data, output_shape, output_data,
                            PReluElementWise, PReluScalarBroadcast);
    return;
  }
  // Arbitrary broadcast shapes: use the slow reference implementation.
  reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
      input_shape, input_data, alpha_shape, alpha_data, output_shape,
      output_data, func);
}
} // namespace optimized_ops
} // namespace tflite