Optimize int8 broadcast min.

PiperOrigin-RevId: 311665392
Change-Id: I566547f44975d3d88cb7a17e8c6418a4a186ccda
This commit is contained in:
Renjie Liu 2020-05-14 21:28:44 -07:00 committed by TensorFlower Gardener
parent 5cf4311435
commit 28899d991f
2 changed files with 124 additions and 23 deletions

View File

@ -7963,14 +7963,59 @@ inline void MaximumScalarBroadcast(int size, const ArithmeticParams& params,
}
}
inline void BroadcastMaximumFivefold(
const ArithmeticParams& unswitched_params,
const RuntimeShape& unswitched_input1_shape,
const int8* unswitched_input1_data,
const RuntimeShape& unswitched_input2_shape,
const int8* unswitched_input2_data, const RuntimeShape& output_shape,
int8* output_data) {
ruy::profiler::ScopeLabel label("BroadcastMaximumFivefoldInt8/8bit");
// Assume input1 & input2 have the same scale & zero point.
inline void MinimumElementwise(int size, const ArithmeticParams& params,
const int8* input1_data, const int8* input2_data,
int8* output_data) {
int i = 0;
#ifdef USE_NEON
for (; i <= size - 16; i += 16) {
const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
const int8x16_t min_data =
vminq_s8(input1_val_original, input2_val_original);
vst1q_s8(output_data + i, min_data);
}
#endif // USE_NEON
for (; i < size; ++i) {
const int8 input1_val = input1_data[i];
const int8 input2_val = input2_data[i];
output_data[i] = std::min(input1_val, input2_val);
}
}
inline void MinimumScalarBroadcast(int size, const ArithmeticParams& params,
int8 input1_data, const int8* input2_data,
int8* output_data) {
int i = 0;
#ifdef USE_NEON
const int8x16_t input1_val_original = vdupq_n_s8(input1_data);
for (; i <= size - 16; i += 16) {
const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
const int8x16_t min_data =
vminq_s8(input1_val_original, input2_val_original);
vst1q_s8(output_data + i, min_data);
}
#endif // USE_NEON
for (; i < size; ++i) {
const int8 input2_val = input2_data[i];
output_data[i] = std::min(input1_data, input2_val);
}
}
template <typename ElementwiseF, typename ScalarBroadcastF>
inline void BinaryBroadcastFiveFold(const ArithmeticParams& unswitched_params,
const RuntimeShape& unswitched_input1_shape,
const int8* unswitched_input1_data,
const RuntimeShape& unswitched_input2_shape,
const int8* unswitched_input2_data,
const RuntimeShape& output_shape,
int8* output_data,
ElementwiseF elementwise_f,
ScalarBroadcastF scalar_broadcast_f,
const std::string& label_name) {
ruy::profiler::ScopeLabel label(label_name);
ArithmeticParams switched_params = unswitched_params;
switched_params.input1_offset = unswitched_params.input2_offset;
@ -8000,9 +8045,8 @@ inline void BroadcastMaximumFivefold(
const int8* input2_data_reset = input2_data;
// In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared
// between input shapes. y3 for input 1 is always broadcast, and so the
// dimension there is 1, whereas optionally y1 might be broadcast for input 2.
// Put another way,
// input1.shape.FlatSize = y0 * y1 * y2 * y4,
// dimension there is 1, whereas optionally y1 might be broadcast for
// input 2. Put another way, input1.shape.FlatSize = y0 * y1 * y2 * y4,
// input2.shape.FlatSize = y0 * y2 * y3 * y4.
int y0 = params.broadcast_shape[0];
int y1 = params.broadcast_shape[1];
@ -8018,8 +8062,8 @@ inline void BroadcastMaximumFivefold(
input2_data_ptr = input2_data_reset;
for (int i2 = 0; i2 < y2; ++i2) {
for (int i3 = 0; i3 < y3; ++i3) {
MaximumElementwise(y4, params, input1_data_ptr, input2_data_ptr,
output_data_ptr);
elementwise_f(y4, params, input1_data_ptr, input2_data_ptr,
output_data_ptr);
input2_data_ptr += y4;
output_data_ptr += y4;
}
@ -8031,23 +8075,23 @@ inline void BroadcastMaximumFivefold(
input2_data_reset = input2_data_ptr;
}
} else {
// Special case of y4 == 1, in which the innermost loop is a single element
// and can be combined with the next (y3) as an inner broadcast.
// Special case of y4 == 1, in which the innermost loop is a single
// element and can be combined with the next (y3) as an inner broadcast.
//
// Note that this handles the case of pure scalar broadcast when
// y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
// broadcast with batch (as y2 > 1).
//
// NOTE The process is the same as the above general case except simplified
// for y4 == 1 and the loop over y3 is contained within the
// NOTE The process is the same as the above general case except
// simplified for y4 == 1 and the loop over y3 is contained within the
// AddScalarBroadcast function.
for (int i0 = 0; i0 < y0; ++i0) {
const int8* input2_data_ptr = nullptr;
for (int i1 = 0; i1 < y1; ++i1) {
input2_data_ptr = input2_data_reset;
for (int i2 = 0; i2 < y2; ++i2) {
MaximumScalarBroadcast(y3, params, *input1_data_ptr, input2_data_ptr,
output_data_ptr);
scalar_broadcast_f(y3, params, *input1_data_ptr, input2_data_ptr,
output_data_ptr);
input2_data_ptr += y3;
output_data_ptr += y3;
input1_data_ptr += 1;
@ -8058,7 +8102,6 @@ inline void BroadcastMaximumFivefold(
}
}
// TODO(b/156140316): Try to unify the broadcast dispatch logic for binary ops.
template <typename Op>
inline void BroadcastMaximumDispatch(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
@ -8073,8 +8116,30 @@ inline void BroadcastMaximumDispatch(const ArithmeticParams& params,
output_data, op);
}
BroadcastMaximumFivefold(params, input1_shape, input1_data, input2_shape,
input2_data, output_shape, output_data);
BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape,
input2_data, output_shape, output_data,
MaximumElementwise, MaximumScalarBroadcast,
"BroadcastMaximumFivefoldInt8/8bit");
}
template <typename Op>
inline void BroadcastMinimumDispatch(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const int8* input1_data,
const RuntimeShape& input2_shape,
const int8* input2_data,
const RuntimeShape& output_shape,
int8* output_data, Op op) {
if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
return reference_ops::MaximumMinimumBroadcastSlow(
input1_shape, input1_data, input2_shape, input2_data, output_shape,
output_data, op);
}
BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape,
input2_data, output_shape, output_data,
MinimumElementwise, MinimumScalarBroadcast,
"BroadcastMinimumFivefoldInt8/8bit");
}
} // namespace optimized_ops

View File

@ -125,6 +125,31 @@ void TFLiteOperation<maximum_minimum::kGenericOptimized, int8, MaximumOp>(
MaximumOp::template op<int8>);
}
// Minimum generic opt int8.
template <>
void TFLiteOperation<maximum_minimum::kGenericOptimized, int8, MinimumOp>(
TfLiteContext* context, TfLiteNode* node, const OpContext& op_context) {
tflite::ArithmeticParams op_params;
const bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
GetTensorShape(op_context.input1), GetTensorShape(op_context.input2),
&op_params);
if (need_broadcast) {
optimized_ops::BroadcastMinimumDispatch(
op_params, GetTensorShape(op_context.input1),
GetTensorData<int8>(op_context.input1),
GetTensorShape(op_context.input2),
GetTensorData<int8>(op_context.input2),
GetTensorShape(op_context.output),
GetTensorData<int8>(op_context.output), MinimumOp::template op<int8>);
return;
}
reference_ops::MaximumMinimumBroadcastSlow(
GetTensorShape(op_context.input1), GetTensorData<int8>(op_context.input1),
GetTensorShape(op_context.input2), GetTensorData<int8>(op_context.input2),
GetTensorShape(op_context.output), GetTensorData<int8>(op_context.output),
MinimumOp::template op<int8>);
}
template <KernelType kernel_type, typename OpType>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
OpContext op_context(context, node);
@ -186,10 +211,21 @@ TfLiteRegistration* Register_MINIMUM_REF() {
maximum_minimum::MinimumOp>};
return &r;
}
TfLiteRegistration* Register_MINIMUM_GENERIC_OPT() {
static TfLiteRegistration r = {
nullptr, nullptr, maximum_minimum::Prepare,
maximum_minimum::Eval<maximum_minimum::kGenericOptimized,
maximum_minimum::MinimumOp>};
return &r;
}
TfLiteRegistration* Register_MAXIMUM() {
return Register_MAXIMUM_GENERIC_OPT();
}
TfLiteRegistration* Register_MINIMUM() { return Register_MINIMUM_REF(); }
TfLiteRegistration* Register_MINIMUM() {
return Register_MINIMUM_GENERIC_OPT();
}
} // namespace builtin
} // namespace ops