Optimize int8 broadcast min.

PiperOrigin-RevId: 311665392
Change-Id: I566547f44975d3d88cb7a17e8c6418a4a186ccda
This commit is contained in:
Renjie Liu 2020-05-14 21:28:44 -07:00 committed by TensorFlower Gardener
parent 5cf4311435
commit 28899d991f
2 changed files with 124 additions and 23 deletions

View File

@ -7963,14 +7963,59 @@ inline void MaximumScalarBroadcast(int size, const ArithmeticParams& params,
} }
} }
inline void BroadcastMaximumFivefold( // Assume input1 & input2 have the same scale & zero point.
const ArithmeticParams& unswitched_params, inline void MinimumElementwise(int size, const ArithmeticParams& params,
const int8* input1_data, const int8* input2_data,
int8* output_data) {
int i = 0;
#ifdef USE_NEON
for (; i <= size - 16; i += 16) {
const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
const int8x16_t min_data =
vminq_s8(input1_val_original, input2_val_original);
vst1q_s8(output_data + i, min_data);
}
#endif // USE_NEON
for (; i < size; ++i) {
const int8 input1_val = input1_data[i];
const int8 input2_val = input2_data[i];
output_data[i] = std::min(input1_val, input2_val);
}
}
// Broadcasts the scalar `input1_data` against the `input2_data` array and
// writes the elementwise int8 minimum into `output_data`. `params` is unused
// but keeps the signature compatible with the ScalarBroadcastF functor slot
// of BinaryBroadcastFiveFold.
inline void MinimumScalarBroadcast(int size, const ArithmeticParams& params,
                                   int8 input1_data, const int8* input2_data,
                                   int8* output_data) {
  int idx = 0;
#ifdef USE_NEON
  // Splat the scalar once into all 16 lanes and reuse it every iteration.
  const int8x16_t scalar_dup = vdupq_n_s8(input1_data);
  for (; idx <= size - 16; idx += 16) {
    const int8x16_t rhs = vld1q_s8(input2_data + idx);
    vst1q_s8(output_data + idx, vminq_s8(scalar_dup, rhs));
  }
#endif  // USE_NEON
  // Scalar tail (and the whole range when NEON is unavailable).
  while (idx < size) {
    output_data[idx] = std::min(input1_data, input2_data[idx]);
    ++idx;
  }
}
template <typename ElementwiseF, typename ScalarBroadcastF>
inline void BinaryBroadcastFiveFold(const ArithmeticParams& unswitched_params,
const RuntimeShape& unswitched_input1_shape, const RuntimeShape& unswitched_input1_shape,
const int8* unswitched_input1_data, const int8* unswitched_input1_data,
const RuntimeShape& unswitched_input2_shape, const RuntimeShape& unswitched_input2_shape,
const int8* unswitched_input2_data, const RuntimeShape& output_shape, const int8* unswitched_input2_data,
int8* output_data) { const RuntimeShape& output_shape,
ruy::profiler::ScopeLabel label("BroadcastMaximumFivefoldInt8/8bit"); int8* output_data,
ElementwiseF elementwise_f,
ScalarBroadcastF scalar_broadcast_f,
const std::string& label_name) {
ruy::profiler::ScopeLabel label(label_name);
ArithmeticParams switched_params = unswitched_params; ArithmeticParams switched_params = unswitched_params;
switched_params.input1_offset = unswitched_params.input2_offset; switched_params.input1_offset = unswitched_params.input2_offset;
@ -8000,9 +8045,8 @@ inline void BroadcastMaximumFivefold(
const int8* input2_data_reset = input2_data; const int8* input2_data_reset = input2_data;
// In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared
// between input shapes. y3 for input 1 is always broadcast, and so the // between input shapes. y3 for input 1 is always broadcast, and so the
// dimension there is 1, whereas optionally y1 might be broadcast for input 2. // dimension there is 1, whereas optionally y1 might be broadcast for
// Put another way, // input 2. Put another way, input1.shape.FlatSize = y0 * y1 * y2 * y4,
// input1.shape.FlatSize = y0 * y1 * y2 * y4,
// input2.shape.FlatSize = y0 * y2 * y3 * y4. // input2.shape.FlatSize = y0 * y2 * y3 * y4.
int y0 = params.broadcast_shape[0]; int y0 = params.broadcast_shape[0];
int y1 = params.broadcast_shape[1]; int y1 = params.broadcast_shape[1];
@ -8018,7 +8062,7 @@ inline void BroadcastMaximumFivefold(
input2_data_ptr = input2_data_reset; input2_data_ptr = input2_data_reset;
for (int i2 = 0; i2 < y2; ++i2) { for (int i2 = 0; i2 < y2; ++i2) {
for (int i3 = 0; i3 < y3; ++i3) { for (int i3 = 0; i3 < y3; ++i3) {
MaximumElementwise(y4, params, input1_data_ptr, input2_data_ptr, elementwise_f(y4, params, input1_data_ptr, input2_data_ptr,
output_data_ptr); output_data_ptr);
input2_data_ptr += y4; input2_data_ptr += y4;
output_data_ptr += y4; output_data_ptr += y4;
@ -8031,22 +8075,22 @@ inline void BroadcastMaximumFivefold(
input2_data_reset = input2_data_ptr; input2_data_reset = input2_data_ptr;
} }
} else { } else {
// Special case of y4 == 1, in which the innermost loop is a single element // Special case of y4 == 1, in which the innermost loop is a single
// and can be combined with the next (y3) as an inner broadcast. // element and can be combined with the next (y3) as an inner broadcast.
// //
// Note that this handles the case of pure scalar broadcast when // Note that this handles the case of pure scalar broadcast when
// y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar // y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
// broadcast with batch (as y2 > 1). // broadcast with batch (as y2 > 1).
// //
// NOTE The process is the same as the above general case except simplified // NOTE The process is the same as the above general case except
// for y4 == 1 and the loop over y3 is contained within the // simplified for y4 == 1 and the loop over y3 is contained within the
// AddScalarBroadcast function. // AddScalarBroadcast function.
for (int i0 = 0; i0 < y0; ++i0) { for (int i0 = 0; i0 < y0; ++i0) {
const int8* input2_data_ptr = nullptr; const int8* input2_data_ptr = nullptr;
for (int i1 = 0; i1 < y1; ++i1) { for (int i1 = 0; i1 < y1; ++i1) {
input2_data_ptr = input2_data_reset; input2_data_ptr = input2_data_reset;
for (int i2 = 0; i2 < y2; ++i2) { for (int i2 = 0; i2 < y2; ++i2) {
MaximumScalarBroadcast(y3, params, *input1_data_ptr, input2_data_ptr, scalar_broadcast_f(y3, params, *input1_data_ptr, input2_data_ptr,
output_data_ptr); output_data_ptr);
input2_data_ptr += y3; input2_data_ptr += y3;
output_data_ptr += y3; output_data_ptr += y3;
@ -8058,7 +8102,6 @@ inline void BroadcastMaximumFivefold(
} }
} }
// TODO(b/156140316): Try to unify the broadcast dispatch logic for binary ops.
template <typename Op> template <typename Op>
inline void BroadcastMaximumDispatch(const ArithmeticParams& params, inline void BroadcastMaximumDispatch(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const RuntimeShape& input1_shape,
@ -8073,8 +8116,30 @@ inline void BroadcastMaximumDispatch(const ArithmeticParams& params,
output_data, op); output_data, op);
} }
BroadcastMaximumFivefold(params, input1_shape, input1_data, input2_shape, BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape,
input2_data, output_shape, output_data); input2_data, output_shape, output_data,
MaximumElementwise, MaximumScalarBroadcast,
"BroadcastMaximumFivefoldInt8/8bit");
}
// Dispatches an int8 broadcast Minimum either to the generic reference slow
// path (arbitrary broadcasts) or to the optimized fivefold implementation.
template <typename Op>
inline void BroadcastMinimumDispatch(const ArithmeticParams& params,
                                     const RuntimeShape& input1_shape,
                                     const int8* input1_data,
                                     const RuntimeShape& input2_shape,
                                     const int8* input2_data,
                                     const RuntimeShape& output_shape,
                                     int8* output_data, Op op) {
  if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
    // Shapes do not fit the fivefold pattern; use the reference fallback.
    return reference_ops::MaximumMinimumBroadcastSlow(
        input1_shape, input1_data, input2_shape, input2_data, output_shape,
        output_data, op);
  }
  // Fivefold fast path: MinimumElementwise handles the inner vectorized loop,
  // MinimumScalarBroadcast the scalar-broadcast (y4 == 1) case.
  BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape,
                          input2_data, output_shape, output_data,
                          MinimumElementwise, MinimumScalarBroadcast,
                          "BroadcastMinimumFivefoldInt8/8bit");
}
} // namespace optimized_ops } // namespace optimized_ops

View File

@ -125,6 +125,31 @@ void TFLiteOperation<maximum_minimum::kGenericOptimized, int8, MaximumOp>(
MaximumOp::template op<int8>); MaximumOp::template op<int8>);
} }
// Minimum generic opt int8.
template <>
void TFLiteOperation<maximum_minimum::kGenericOptimized, int8, MinimumOp>(
    TfLiteContext* context, TfLiteNode* node, const OpContext& op_context) {
  tflite::ArithmeticParams arith_params;
  const bool requires_broadcast = optimized_ops::ProcessBroadcastShapes(
      GetTensorShape(op_context.input1), GetTensorShape(op_context.input2),
      &arith_params);
  if (!requires_broadcast) {
    // ProcessBroadcastShapes reported no broadcasting; evaluate directly via
    // the reference implementation.
    reference_ops::MaximumMinimumBroadcastSlow(
        GetTensorShape(op_context.input1),
        GetTensorData<int8>(op_context.input1),
        GetTensorShape(op_context.input2),
        GetTensorData<int8>(op_context.input2),
        GetTensorShape(op_context.output),
        GetTensorData<int8>(op_context.output), MinimumOp::template op<int8>);
    return;
  }
  // Broadcasting required: route through the optimized dispatcher.
  optimized_ops::BroadcastMinimumDispatch(
      arith_params, GetTensorShape(op_context.input1),
      GetTensorData<int8>(op_context.input1),
      GetTensorShape(op_context.input2),
      GetTensorData<int8>(op_context.input2),
      GetTensorShape(op_context.output),
      GetTensorData<int8>(op_context.output), MinimumOp::template op<int8>);
}
template <KernelType kernel_type, typename OpType> template <KernelType kernel_type, typename OpType>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
OpContext op_context(context, node); OpContext op_context(context, node);
@ -186,10 +211,21 @@ TfLiteRegistration* Register_MINIMUM_REF() {
maximum_minimum::MinimumOp>}; maximum_minimum::MinimumOp>};
return &r; return &r;
} }
// Registration entry point for the generic-optimized Minimum kernel.
TfLiteRegistration* Register_MINIMUM_GENERIC_OPT() {
  static TfLiteRegistration registration = {
      nullptr, nullptr, maximum_minimum::Prepare,
      maximum_minimum::Eval<maximum_minimum::kGenericOptimized,
                            maximum_minimum::MinimumOp>};
  return &registration;
}
// Default Maximum registration: the generic-optimized kernel.
// NOTE(review): this span was a duplicated (old+new column) diff context
// line; reconstructed — both columns were identical.
TfLiteRegistration* Register_MAXIMUM() {
  return Register_MAXIMUM_GENERIC_OPT();
}
// Default Minimum registration: now the generic-optimized kernel (this commit
// switches it from Register_MINIMUM_REF, visible as the old diff column).
TfLiteRegistration* Register_MINIMUM() {
  return Register_MINIMUM_GENERIC_OPT();
}
} // namespace builtin } // namespace builtin
} // namespace ops } // namespace ops