Optimize int8 broadcast min.
PiperOrigin-RevId: 311665392 Change-Id: I566547f44975d3d88cb7a17e8c6418a4a186ccda
This commit is contained in:
parent
5cf4311435
commit
28899d991f
|
@ -7963,14 +7963,59 @@ inline void MaximumScalarBroadcast(int size, const ArithmeticParams& params,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void BroadcastMaximumFivefold(
|
// Assume input1 & input2 have the same scale & zero point.
|
||||||
const ArithmeticParams& unswitched_params,
|
inline void MinimumElementwise(int size, const ArithmeticParams& params,
|
||||||
|
const int8* input1_data, const int8* input2_data,
|
||||||
|
int8* output_data) {
|
||||||
|
int i = 0;
|
||||||
|
#ifdef USE_NEON
|
||||||
|
for (; i <= size - 16; i += 16) {
|
||||||
|
const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
|
||||||
|
const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
|
||||||
|
const int8x16_t min_data =
|
||||||
|
vminq_s8(input1_val_original, input2_val_original);
|
||||||
|
vst1q_s8(output_data + i, min_data);
|
||||||
|
}
|
||||||
|
#endif // USE_NEON
|
||||||
|
for (; i < size; ++i) {
|
||||||
|
const int8 input1_val = input1_data[i];
|
||||||
|
const int8 input2_val = input2_data[i];
|
||||||
|
output_data[i] = std::min(input1_val, input2_val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void MinimumScalarBroadcast(int size, const ArithmeticParams& params,
|
||||||
|
int8 input1_data, const int8* input2_data,
|
||||||
|
int8* output_data) {
|
||||||
|
int i = 0;
|
||||||
|
|
||||||
|
#ifdef USE_NEON
|
||||||
|
const int8x16_t input1_val_original = vdupq_n_s8(input1_data);
|
||||||
|
for (; i <= size - 16; i += 16) {
|
||||||
|
const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
|
||||||
|
const int8x16_t min_data =
|
||||||
|
vminq_s8(input1_val_original, input2_val_original);
|
||||||
|
vst1q_s8(output_data + i, min_data);
|
||||||
|
}
|
||||||
|
#endif // USE_NEON
|
||||||
|
for (; i < size; ++i) {
|
||||||
|
const int8 input2_val = input2_data[i];
|
||||||
|
output_data[i] = std::min(input1_data, input2_val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ElementwiseF, typename ScalarBroadcastF>
|
||||||
|
inline void BinaryBroadcastFiveFold(const ArithmeticParams& unswitched_params,
|
||||||
const RuntimeShape& unswitched_input1_shape,
|
const RuntimeShape& unswitched_input1_shape,
|
||||||
const int8* unswitched_input1_data,
|
const int8* unswitched_input1_data,
|
||||||
const RuntimeShape& unswitched_input2_shape,
|
const RuntimeShape& unswitched_input2_shape,
|
||||||
const int8* unswitched_input2_data, const RuntimeShape& output_shape,
|
const int8* unswitched_input2_data,
|
||||||
int8* output_data) {
|
const RuntimeShape& output_shape,
|
||||||
ruy::profiler::ScopeLabel label("BroadcastMaximumFivefoldInt8/8bit");
|
int8* output_data,
|
||||||
|
ElementwiseF elementwise_f,
|
||||||
|
ScalarBroadcastF scalar_broadcast_f,
|
||||||
|
const std::string& label_name) {
|
||||||
|
ruy::profiler::ScopeLabel label(label_name);
|
||||||
|
|
||||||
ArithmeticParams switched_params = unswitched_params;
|
ArithmeticParams switched_params = unswitched_params;
|
||||||
switched_params.input1_offset = unswitched_params.input2_offset;
|
switched_params.input1_offset = unswitched_params.input2_offset;
|
||||||
|
@ -8000,9 +8045,8 @@ inline void BroadcastMaximumFivefold(
|
||||||
const int8* input2_data_reset = input2_data;
|
const int8* input2_data_reset = input2_data;
|
||||||
// In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared
|
// In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared
|
||||||
// between input shapes. y3 for input 1 is always broadcast, and so the
|
// between input shapes. y3 for input 1 is always broadcast, and so the
|
||||||
// dimension there is 1, whereas optionally y1 might be broadcast for input 2.
|
// dimension there is 1, whereas optionally y1 might be broadcast for
|
||||||
// Put another way,
|
// input 2. Put another way, input1.shape.FlatSize = y0 * y1 * y2 * y4,
|
||||||
// input1.shape.FlatSize = y0 * y1 * y2 * y4,
|
|
||||||
// input2.shape.FlatSize = y0 * y2 * y3 * y4.
|
// input2.shape.FlatSize = y0 * y2 * y3 * y4.
|
||||||
int y0 = params.broadcast_shape[0];
|
int y0 = params.broadcast_shape[0];
|
||||||
int y1 = params.broadcast_shape[1];
|
int y1 = params.broadcast_shape[1];
|
||||||
|
@ -8018,7 +8062,7 @@ inline void BroadcastMaximumFivefold(
|
||||||
input2_data_ptr = input2_data_reset;
|
input2_data_ptr = input2_data_reset;
|
||||||
for (int i2 = 0; i2 < y2; ++i2) {
|
for (int i2 = 0; i2 < y2; ++i2) {
|
||||||
for (int i3 = 0; i3 < y3; ++i3) {
|
for (int i3 = 0; i3 < y3; ++i3) {
|
||||||
MaximumElementwise(y4, params, input1_data_ptr, input2_data_ptr,
|
elementwise_f(y4, params, input1_data_ptr, input2_data_ptr,
|
||||||
output_data_ptr);
|
output_data_ptr);
|
||||||
input2_data_ptr += y4;
|
input2_data_ptr += y4;
|
||||||
output_data_ptr += y4;
|
output_data_ptr += y4;
|
||||||
|
@ -8031,22 +8075,22 @@ inline void BroadcastMaximumFivefold(
|
||||||
input2_data_reset = input2_data_ptr;
|
input2_data_reset = input2_data_ptr;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Special case of y4 == 1, in which the innermost loop is a single element
|
// Special case of y4 == 1, in which the innermost loop is a single
|
||||||
// and can be combined with the next (y3) as an inner broadcast.
|
// element and can be combined with the next (y3) as an inner broadcast.
|
||||||
//
|
//
|
||||||
// Note that this handles the case of pure scalar broadcast when
|
// Note that this handles the case of pure scalar broadcast when
|
||||||
// y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
|
// y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
|
||||||
// broadcast with batch (as y2 > 1).
|
// broadcast with batch (as y2 > 1).
|
||||||
//
|
//
|
||||||
// NOTE The process is the same as the above general case except simplified
|
// NOTE The process is the same as the above general case except
|
||||||
// for y4 == 1 and the loop over y3 is contained within the
|
// simplified for y4 == 1 and the loop over y3 is contained within the
|
||||||
// AddScalarBroadcast function.
|
// AddScalarBroadcast function.
|
||||||
for (int i0 = 0; i0 < y0; ++i0) {
|
for (int i0 = 0; i0 < y0; ++i0) {
|
||||||
const int8* input2_data_ptr = nullptr;
|
const int8* input2_data_ptr = nullptr;
|
||||||
for (int i1 = 0; i1 < y1; ++i1) {
|
for (int i1 = 0; i1 < y1; ++i1) {
|
||||||
input2_data_ptr = input2_data_reset;
|
input2_data_ptr = input2_data_reset;
|
||||||
for (int i2 = 0; i2 < y2; ++i2) {
|
for (int i2 = 0; i2 < y2; ++i2) {
|
||||||
MaximumScalarBroadcast(y3, params, *input1_data_ptr, input2_data_ptr,
|
scalar_broadcast_f(y3, params, *input1_data_ptr, input2_data_ptr,
|
||||||
output_data_ptr);
|
output_data_ptr);
|
||||||
input2_data_ptr += y3;
|
input2_data_ptr += y3;
|
||||||
output_data_ptr += y3;
|
output_data_ptr += y3;
|
||||||
|
@ -8058,7 +8102,6 @@ inline void BroadcastMaximumFivefold(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(b/156140316): Try to unify the broadcast dispatch logic for binary ops.
|
|
||||||
template <typename Op>
|
template <typename Op>
|
||||||
inline void BroadcastMaximumDispatch(const ArithmeticParams& params,
|
inline void BroadcastMaximumDispatch(const ArithmeticParams& params,
|
||||||
const RuntimeShape& input1_shape,
|
const RuntimeShape& input1_shape,
|
||||||
|
@ -8073,8 +8116,30 @@ inline void BroadcastMaximumDispatch(const ArithmeticParams& params,
|
||||||
output_data, op);
|
output_data, op);
|
||||||
}
|
}
|
||||||
|
|
||||||
BroadcastMaximumFivefold(params, input1_shape, input1_data, input2_shape,
|
BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape,
|
||||||
input2_data, output_shape, output_data);
|
input2_data, output_shape, output_data,
|
||||||
|
MaximumElementwise, MaximumScalarBroadcast,
|
||||||
|
"BroadcastMaximumFivefoldInt8/8bit");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
inline void BroadcastMinimumDispatch(const ArithmeticParams& params,
|
||||||
|
const RuntimeShape& input1_shape,
|
||||||
|
const int8* input1_data,
|
||||||
|
const RuntimeShape& input2_shape,
|
||||||
|
const int8* input2_data,
|
||||||
|
const RuntimeShape& output_shape,
|
||||||
|
int8* output_data, Op op) {
|
||||||
|
if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
|
||||||
|
return reference_ops::MaximumMinimumBroadcastSlow(
|
||||||
|
input1_shape, input1_data, input2_shape, input2_data, output_shape,
|
||||||
|
output_data, op);
|
||||||
|
}
|
||||||
|
|
||||||
|
BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape,
|
||||||
|
input2_data, output_shape, output_data,
|
||||||
|
MinimumElementwise, MinimumScalarBroadcast,
|
||||||
|
"BroadcastMinimumFivefoldInt8/8bit");
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace optimized_ops
|
} // namespace optimized_ops
|
||||||
|
|
|
@ -125,6 +125,31 @@ void TFLiteOperation<maximum_minimum::kGenericOptimized, int8, MaximumOp>(
|
||||||
MaximumOp::template op<int8>);
|
MaximumOp::template op<int8>);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Minimum generic opt int8.
|
||||||
|
template <>
|
||||||
|
void TFLiteOperation<maximum_minimum::kGenericOptimized, int8, MinimumOp>(
|
||||||
|
TfLiteContext* context, TfLiteNode* node, const OpContext& op_context) {
|
||||||
|
tflite::ArithmeticParams op_params;
|
||||||
|
const bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
|
||||||
|
GetTensorShape(op_context.input1), GetTensorShape(op_context.input2),
|
||||||
|
&op_params);
|
||||||
|
if (need_broadcast) {
|
||||||
|
optimized_ops::BroadcastMinimumDispatch(
|
||||||
|
op_params, GetTensorShape(op_context.input1),
|
||||||
|
GetTensorData<int8>(op_context.input1),
|
||||||
|
GetTensorShape(op_context.input2),
|
||||||
|
GetTensorData<int8>(op_context.input2),
|
||||||
|
GetTensorShape(op_context.output),
|
||||||
|
GetTensorData<int8>(op_context.output), MinimumOp::template op<int8>);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
reference_ops::MaximumMinimumBroadcastSlow(
|
||||||
|
GetTensorShape(op_context.input1), GetTensorData<int8>(op_context.input1),
|
||||||
|
GetTensorShape(op_context.input2), GetTensorData<int8>(op_context.input2),
|
||||||
|
GetTensorShape(op_context.output), GetTensorData<int8>(op_context.output),
|
||||||
|
MinimumOp::template op<int8>);
|
||||||
|
}
|
||||||
|
|
||||||
template <KernelType kernel_type, typename OpType>
|
template <KernelType kernel_type, typename OpType>
|
||||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
OpContext op_context(context, node);
|
OpContext op_context(context, node);
|
||||||
|
@ -186,10 +211,21 @@ TfLiteRegistration* Register_MINIMUM_REF() {
|
||||||
maximum_minimum::MinimumOp>};
|
maximum_minimum::MinimumOp>};
|
||||||
return &r;
|
return &r;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TfLiteRegistration* Register_MINIMUM_GENERIC_OPT() {
|
||||||
|
static TfLiteRegistration r = {
|
||||||
|
nullptr, nullptr, maximum_minimum::Prepare,
|
||||||
|
maximum_minimum::Eval<maximum_minimum::kGenericOptimized,
|
||||||
|
maximum_minimum::MinimumOp>};
|
||||||
|
return &r;
|
||||||
|
}
|
||||||
|
|
||||||
TfLiteRegistration* Register_MAXIMUM() {
|
TfLiteRegistration* Register_MAXIMUM() {
|
||||||
return Register_MAXIMUM_GENERIC_OPT();
|
return Register_MAXIMUM_GENERIC_OPT();
|
||||||
}
|
}
|
||||||
TfLiteRegistration* Register_MINIMUM() { return Register_MINIMUM_REF(); }
|
TfLiteRegistration* Register_MINIMUM() {
|
||||||
|
return Register_MINIMUM_GENERIC_OPT();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace builtin
|
} // namespace builtin
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
Loading…
Reference in New Issue