Optimize the float PRelu kernel.
PiperOrigin-RevId: 332182719 Change-Id: Idee8c15d156a5dab62771d126fbec1eb5546fba9
This commit is contained in:
parent
151800d58a
commit
853223a456
@ -1181,6 +1181,7 @@ T ApplyPrelu(T input, T alpha) {
|
||||
return input >= 0.0 ? input : input * alpha;
|
||||
}
|
||||
|
||||
template <KernelType kernel_type>
|
||||
TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||
const TfLiteTensor* alpha = GetInput(context, node, 1);
|
||||
@ -1188,18 +1189,38 @@ TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const PreluOpData* data = reinterpret_cast<PreluOpData*>(node->user_data);
|
||||
switch (input->type) {
|
||||
case kTfLiteFloat32: {
|
||||
if (data->requires_broadcast) {
|
||||
reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
|
||||
GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(alpha), GetTensorData<float>(alpha),
|
||||
GetTensorShape(output), GetTensorData<float>(output),
|
||||
ApplyPrelu<float>);
|
||||
if (kernel_type == kGenericOptimized) {
|
||||
tflite::ArithmeticParams op_params;
|
||||
bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
|
||||
GetTensorShape(input), GetTensorShape(alpha), &op_params);
|
||||
if (need_broadcast) {
|
||||
optimized_ops::BroadcastPReluDispatch(
|
||||
op_params, GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(alpha), GetTensorData<float>(alpha),
|
||||
GetTensorShape(output), GetTensorData<float>(output),
|
||||
ApplyPrelu<float>);
|
||||
} else {
|
||||
const int flat_size =
|
||||
MatchingElementsSize(GetTensorShape(input), GetTensorShape(alpha),
|
||||
GetTensorShape(output));
|
||||
optimized_ops::PReluElementWise(
|
||||
flat_size, op_params, GetTensorData<float>(alpha),
|
||||
GetTensorData<float>(input), GetTensorData<float>(output));
|
||||
}
|
||||
} else {
|
||||
reference_ops::BinaryFunction<float, float, float>(
|
||||
GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(alpha), GetTensorData<float>(alpha),
|
||||
GetTensorShape(output), GetTensorData<float>(output),
|
||||
ApplyPrelu<float>);
|
||||
if (data->requires_broadcast) {
|
||||
reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
|
||||
GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(alpha), GetTensorData<float>(alpha),
|
||||
GetTensorShape(output), GetTensorData<float>(output),
|
||||
ApplyPrelu<float>);
|
||||
} else {
|
||||
reference_ops::BinaryFunction<float, float, float>(
|
||||
GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(alpha), GetTensorData<float>(alpha),
|
||||
GetTensorShape(output), GetTensorData<float>(output),
|
||||
ApplyPrelu<float>);
|
||||
}
|
||||
}
|
||||
return kTfLiteOk;
|
||||
} break;
|
||||
@ -1460,10 +1481,17 @@ TfLiteRegistration* Register_LOG_SOFTMAX() {
|
||||
return &r;
|
||||
}
|
||||
|
||||
TfLiteRegistration* Register_PRELU_REF() {
  // Registration for the reference PRelu kernel: shares the common
  // init/free/prepare hooks and selects the kReference eval specialization.
  static TfLiteRegistration registration = {
      activations::PreluInit, activations::PreluFree,
      activations::PreluPrepare,
      activations::PreluEval<activations::kReference>};
  return &registration;
}
|
||||
|
||||
TfLiteRegistration* Register_PRELU() {
  // Default PRelu registration: uses the generic-optimized eval
  // specialization (see PreluEval<kGenericOptimized>).
  //
  // NOTE: the mis-merged diff left both the old registration (non-template
  // PreluEval) and the new one in this function body, redefining `r`; only
  // the new optimized registration is kept here.
  static TfLiteRegistration r = {
      activations::PreluInit, activations::PreluFree, activations::PreluPrepare,
      activations::PreluEval<activations::kGenericOptimized>};
  return &r;
}
|
||||
|
||||
|
@ -50,6 +50,10 @@ TfLiteRegistration* Register_LOGISTIC_REF();
|
||||
TfLiteRegistration* Register_LOGISTIC_GENERIC_OPT();
|
||||
TfLiteRegistration* Register_LOGISTIC_FIXED_POINT_OPT();
|
||||
|
||||
// PRelu kernel registrations.
|
||||
TfLiteRegistration* Register_PRELU_REF();
|
||||
TfLiteRegistration* Register_PRELU();
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
|
||||
@ -2031,6 +2035,11 @@ TEST(QuantizedActivationsOpTest, LogSoftmaxInt8) {
|
||||
}));
|
||||
}
|
||||
|
||||
const auto kPReluKernelMap = new std::map<string, TfLiteRegistration*>({
|
||||
{"Reference", ops::builtin::Register_PRELU_REF()},
|
||||
{"GenericOptimized", ops::builtin::Register_PRELU()},
|
||||
});
|
||||
|
||||
// A base class of PRelu op model. It provides the constructor for
|
||||
// FloatPReluOpModel and QuantizedPReluOpModel.
|
||||
class BasePReluOpModel : public SingleOpModel {
|
||||
@ -2087,7 +2096,14 @@ class QuantizedPReluOpModel : public BasePReluOpModel {
|
||||
}
|
||||
};
|
||||
|
||||
TEST(FloatActivationsOpTest, PRelu) {
|
||||
class PReluOpTest : public SingleOpTest {
|
||||
protected:
|
||||
const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
|
||||
return *kPReluKernelMap;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(PReluOpTest, PReluFloat32) {
|
||||
FloatPReluOpModel m({TensorType_FLOAT32, {1, 2, 2, 3}},
|
||||
{TensorType_FLOAT32, {1, 1, 3}});
|
||||
|
||||
@ -2107,7 +2123,7 @@ TEST(FloatActivationsOpTest, PRelu) {
|
||||
}));
|
||||
}
|
||||
|
||||
TEST(FloatActivationsOpTest, PReluSameShapes) {
|
||||
TEST_P(PReluOpTest, PReluFloat32SameShapes) {
|
||||
FloatPReluOpModel m({TensorType_FLOAT32, {1, 2, 2, 3}},
|
||||
{TensorType_FLOAT32, {1, 2, 2, 3}});
|
||||
|
||||
@ -2132,7 +2148,7 @@ TEST(FloatActivationsOpTest, PReluSameShapes) {
|
||||
}));
|
||||
}
|
||||
|
||||
TEST(QuantizedActivationsOpTest, PRelu) {
|
||||
TEST_P(PReluOpTest, PReluUInt8) {
|
||||
const float kMin = -1;
|
||||
const float kMax = 127.f / 128.f;
|
||||
QuantizedPReluOpModel m({TensorType_UINT8, {1, 2, 2, 3}, kMin, kMax},
|
||||
@ -2162,7 +2178,7 @@ TEST(QuantizedActivationsOpTest, PRelu) {
|
||||
}));
|
||||
}
|
||||
|
||||
TEST(QuantizedActivationsOpTest, PReluSameShapes) {
|
||||
TEST_P(PReluOpTest, PReluUInt8SameShapes) {
|
||||
const float kMin = -1;
|
||||
const float kMax = 127.f / 128.f;
|
||||
QuantizedPReluOpModel m({TensorType_UINT8, {1, 2, 2, 3}, kMin, kMax},
|
||||
@ -2197,7 +2213,7 @@ TEST(QuantizedActivationsOpTest, PReluSameShapes) {
|
||||
}));
|
||||
}
|
||||
|
||||
TEST(QuantizedActivationsOpTest, PReluInt8) {
|
||||
TEST_P(PReluOpTest, PReluInt8) {
|
||||
const float kMin = -1;
|
||||
const float kMax = 127.f / 128.f;
|
||||
QuantizedPReluOpModel m({TensorType_INT8, {1, 2, 2, 3}, kMin, kMax},
|
||||
@ -2227,7 +2243,7 @@ TEST(QuantizedActivationsOpTest, PReluInt8) {
|
||||
}));
|
||||
}
|
||||
|
||||
TEST(QuantizedActivationsOpTest, PReluInt8SameShapes) {
|
||||
TEST_P(PReluOpTest, PReluInt8SameShapes) {
|
||||
const float kMin = -1;
|
||||
const float kMax = 127.f / 128.f;
|
||||
QuantizedPReluOpModel m({TensorType_INT8, {1, 2, 2, 3}, kMin, kMax},
|
||||
@ -2303,5 +2319,9 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
LogisticOpTest, LogisticOpTest,
|
||||
::testing::ValuesIn(SingleOpTest::GetKernelTags(*kLogisticKernelMap)));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
PReluOpTest, PReluOpTest,
|
||||
::testing::ValuesIn(SingleOpTest::GetKernelTags(*kPReluKernelMap)));
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
@ -8172,6 +8172,126 @@ void CumSum(const T* input_data, const RuntimeShape& shape, int axis,
|
||||
CumsumImpl<T>(input_data, shape, axis, exclusive, reverse, output_data);
|
||||
}
|
||||
|
||||
// PRelu with a single scalar `alpha` broadcast across every element:
//   output[i] = input[i] >= 0 ? input[i] : input[i] * alpha.
// `params` is not read here; the signature matches the scalar-broadcast
// callback shape passed to BinaryBroadcastFiveFold (see
// BroadcastPReluDispatch below).
inline void PReluScalarBroadcast(int size, const ArithmeticParams& params,
                                 float alpha, const float* input_data,
                                 float* output_data) {
  ruy::profiler::ScopeLabel label("PreluScalarBroadcast/float");
  int idx = 0;

#ifdef USE_NEON
  const float32x4_t zeros = vdupq_n_f32(0.0f);
  const float32x4_t alphas = vdupq_n_f32(alpha);
  // Unrolled main loop: 16 floats (four q-registers) per iteration.
  for (; idx <= size - 16; idx += 16) {
    const float32x4_t in0 = vld1q_f32(input_data + idx);
    const float32x4_t in1 = vld1q_f32(input_data + idx + 4);
    const float32x4_t in2 = vld1q_f32(input_data + idx + 8);
    const float32x4_t in3 = vld1q_f32(input_data + idx + 12);

    const float32x4_t scaled0 = vmulq_f32(in0, alphas);
    const float32x4_t scaled1 = vmulq_f32(in1, alphas);
    const float32x4_t scaled2 = vmulq_f32(in2, alphas);
    const float32x4_t scaled3 = vmulq_f32(in3, alphas);

    // Per-lane select: keep the input where it is >= 0, else the scaled value.
    const uint32x4_t nonneg0 = vcgeq_f32(in0, zeros);
    const uint32x4_t nonneg1 = vcgeq_f32(in1, zeros);
    const uint32x4_t nonneg2 = vcgeq_f32(in2, zeros);
    const uint32x4_t nonneg3 = vcgeq_f32(in3, zeros);

    vst1q_f32(output_data + idx, vbslq_f32(nonneg0, in0, scaled0));
    vst1q_f32(output_data + idx + 4, vbslq_f32(nonneg1, in1, scaled1));
    vst1q_f32(output_data + idx + 8, vbslq_f32(nonneg2, in2, scaled2));
    vst1q_f32(output_data + idx + 12, vbslq_f32(nonneg3, in3, scaled3));
  }

  // Vector tail: one q-register (4 floats) at a time.
  for (; idx <= size - 4; idx += 4) {
    const float32x4_t in = vld1q_f32(input_data + idx);
    const float32x4_t scaled = vmulq_f32(in, alphas);
    const uint32x4_t nonneg = vcgeq_f32(in, zeros);
    vst1q_f32(output_data + idx, vbslq_f32(nonneg, in, scaled));
  }
#endif  // USE_NEON
  // Scalar remainder (the whole range when NEON is unavailable).
  for (; idx < size; ++idx) {
    const float in = input_data[idx];
    output_data[idx] = in >= 0.f ? in : in * alpha;
  }
}
|
||||
|
||||
// Element-wise PRelu where `alpha_data` supplies one alpha per element:
//   output[i] = input[i] >= 0 ? input[i] : input[i] * alpha_data[i].
// `params` is not read here; the signature matches the element-wise callback
// shape passed to BinaryBroadcastFiveFold (see BroadcastPReluDispatch below).
inline void PReluElementWise(int flat_size, const ArithmeticParams& params,
                             const float* alpha_data, const float* input_data,
                             float* output_data) {
  ruy::profiler::ScopeLabel label("PreluElementWise/float");

  int idx = 0;
#ifdef USE_NEON
  const float32x4_t zeros = vdupq_n_f32(0.0f);
  // Unrolled main loop: 16 floats (four q-registers) per iteration.
  for (; idx <= flat_size - 16; idx += 16) {
    const float32x4_t in0 = vld1q_f32(input_data + idx);
    const float32x4_t al0 = vld1q_f32(alpha_data + idx);
    const float32x4_t in1 = vld1q_f32(input_data + idx + 4);
    const float32x4_t al1 = vld1q_f32(alpha_data + idx + 4);
    const float32x4_t in2 = vld1q_f32(input_data + idx + 8);
    const float32x4_t al2 = vld1q_f32(alpha_data + idx + 8);
    const float32x4_t in3 = vld1q_f32(input_data + idx + 12);
    const float32x4_t al3 = vld1q_f32(alpha_data + idx + 12);

    const float32x4_t scaled0 = vmulq_f32(in0, al0);
    const float32x4_t scaled1 = vmulq_f32(in1, al1);
    const float32x4_t scaled2 = vmulq_f32(in2, al2);
    const float32x4_t scaled3 = vmulq_f32(in3, al3);

    // Per-lane select: keep the input where it is >= 0, else the scaled value.
    const uint32x4_t nonneg0 = vcgeq_f32(in0, zeros);
    const uint32x4_t nonneg1 = vcgeq_f32(in1, zeros);
    const uint32x4_t nonneg2 = vcgeq_f32(in2, zeros);
    const uint32x4_t nonneg3 = vcgeq_f32(in3, zeros);

    vst1q_f32(output_data + idx, vbslq_f32(nonneg0, in0, scaled0));
    vst1q_f32(output_data + idx + 4, vbslq_f32(nonneg1, in1, scaled1));
    vst1q_f32(output_data + idx + 8, vbslq_f32(nonneg2, in2, scaled2));
    vst1q_f32(output_data + idx + 12, vbslq_f32(nonneg3, in3, scaled3));
  }

  // Vector tail: one q-register (4 floats) at a time.
  for (; idx <= flat_size - 4; idx += 4) {
    const float32x4_t in = vld1q_f32(input_data + idx);
    const float32x4_t al = vld1q_f32(alpha_data + idx);
    const float32x4_t scaled = vmulq_f32(in, al);
    const uint32x4_t nonneg = vcgeq_f32(in, zeros);
    vst1q_f32(output_data + idx, vbslq_f32(nonneg, in, scaled));
  }
#endif  // USE_NEON
  // Scalar remainder (the whole range when NEON is unavailable).
  for (; idx < flat_size; ++idx) {
    const float in = input_data[idx];
    output_data[idx] = in >= 0.f ? in : in * alpha_data[idx];
  }
}
|
||||
|
||||
inline void BroadcastPReluDispatch(
|
||||
const ArithmeticParams& params, const RuntimeShape& input_shape,
|
||||
const float* input_data, const RuntimeShape& alpha_shape,
|
||||
const float* alpha_data, const RuntimeShape& output_shape,
|
||||
float* output_data, float (*func)(float, float)) {
|
||||
if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
|
||||
return reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
|
||||
input_shape, input_data, alpha_shape, alpha_data, output_shape,
|
||||
output_data, func);
|
||||
}
|
||||
|
||||
BinaryBroadcastFiveFold(params, input_shape, input_data, alpha_shape,
|
||||
alpha_data, output_shape, output_data,
|
||||
PReluElementWise, PReluScalarBroadcast);
|
||||
}
|
||||
|
||||
} // namespace optimized_ops
|
||||
} // namespace tflite
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user