From 28899d991f8f7443a04343fe9f308a1ea28a0795 Mon Sep 17 00:00:00 2001
From: Renjie Liu
Date: Thu, 14 May 2020 21:28:44 -0700
Subject: [PATCH] Optimize int8 broadcast min.

PiperOrigin-RevId: 311665392
Change-Id: I566547f44975d3d88cb7a17e8c6418a4a186ccda
---
 .../internal/optimized/optimized_ops.h      | 109 ++++++++++++++----
 tensorflow/lite/kernels/maximum_minimum.cc  |  38 +++++-
 2 files changed, 124 insertions(+), 23 deletions(-)

diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
index c72400f33a5..b18f0f4bb5a 100644
--- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -7963,14 +7963,59 @@ inline void MaximumScalarBroadcast(int size, const ArithmeticParams& params,
   }
 }
 
-inline void BroadcastMaximumFivefold(
-    const ArithmeticParams& unswitched_params,
-    const RuntimeShape& unswitched_input1_shape,
-    const int8* unswitched_input1_data,
-    const RuntimeShape& unswitched_input2_shape,
-    const int8* unswitched_input2_data, const RuntimeShape& output_shape,
-    int8* output_data) {
-  ruy::profiler::ScopeLabel label("BroadcastMaximumFivefoldInt8/8bit");
+// Assume input1 & input2 have the same scale & zero point.
+inline void MinimumElementwise(int size, const ArithmeticParams& params,
+                               const int8* input1_data,
+                               const int8* input2_data, int8* output_data) {
+  int i = 0;
+#ifdef USE_NEON
+  for (; i <= size - 16; i += 16) {
+    const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
+    const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
+    const int8x16_t min_data =
+        vminq_s8(input1_val_original, input2_val_original);
+    vst1q_s8(output_data + i, min_data);
+  }
+#endif  // USE_NEON
+  for (; i < size; ++i) {
+    const int8 input1_val = input1_data[i];
+    const int8 input2_val = input2_data[i];
+    output_data[i] = std::min(input1_val, input2_val);
+  }
+}
+
+inline void MinimumScalarBroadcast(int size, const ArithmeticParams& params,
+                                   int8 input1_data, const int8* input2_data,
+                                   int8* output_data) {
+  int i = 0;
+
+#ifdef USE_NEON
+  const int8x16_t input1_val_original = vdupq_n_s8(input1_data);
+  for (; i <= size - 16; i += 16) {
+    const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
+    const int8x16_t min_data =
+        vminq_s8(input1_val_original, input2_val_original);
+    vst1q_s8(output_data + i, min_data);
+  }
+#endif  // USE_NEON
+  for (; i < size; ++i) {
+    const int8 input2_val = input2_data[i];
+    output_data[i] = std::min(input1_data, input2_val);
+  }
+}
+
+template <typename ElementwiseF, typename ScalarBroadcastF>
+inline void BinaryBroadcastFiveFold(const ArithmeticParams& unswitched_params,
+                                    const RuntimeShape& unswitched_input1_shape,
+                                    const int8* unswitched_input1_data,
+                                    const RuntimeShape& unswitched_input2_shape,
+                                    const int8* unswitched_input2_data,
+                                    const RuntimeShape& output_shape,
+                                    int8* output_data,
+                                    ElementwiseF elementwise_f,
+                                    ScalarBroadcastF scalar_broadcast_f,
+                                    const std::string& label_name) {
+  ruy::profiler::ScopeLabel label(label_name);
   ArithmeticParams switched_params = unswitched_params;
   switched_params.input1_offset = unswitched_params.input2_offset;
   switched_params.input2_offset = unswitched_params.input1_offset;
@@ -8000,9 +8045,8 @@ inline void BroadcastMaximumFivefold(
   const int8* input2_data_reset = input2_data;
   // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared
   // between input shapes. y3 for input 1 is always broadcast, and so the
-  // dimension there is 1, whereas optionally y1 might be broadcast for input 2.
-  // Put another way,
-  // input1.shape.FlatSize = y0 * y1 * y2 * y4,
+  // dimension there is 1, whereas optionally y1 might be broadcast for
+  // input 2. Put another way, input1.shape.FlatSize = y0 * y1 * y2 * y4,
   // input2.shape.FlatSize = y0 * y2 * y3 * y4.
   int y0 = params.broadcast_shape[0];
   int y1 = params.broadcast_shape[1];
@@ -8018,8 +8062,8 @@ inline void BroadcastMaximumFivefold(
         input2_data_ptr = input2_data_reset;
         for (int i2 = 0; i2 < y2; ++i2) {
           for (int i3 = 0; i3 < y3; ++i3) {
-            MaximumElementwise(y4, params, input1_data_ptr, input2_data_ptr,
-                               output_data_ptr);
+            elementwise_f(y4, params, input1_data_ptr, input2_data_ptr,
+                          output_data_ptr);
             input2_data_ptr += y4;
             output_data_ptr += y4;
           }
@@ -8031,23 +8075,23 @@ inline void BroadcastMaximumFivefold(
       input2_data_reset = input2_data_ptr;
     }
   } else {
-    // Special case of y4 == 1, in which the innermost loop is a single element
-    // and can be combined with the next (y3) as an inner broadcast.
+    // Special case of y4 == 1, in which the innermost loop is a single
+    // element and can be combined with the next (y3) as an inner broadcast.
     //
     // Note that this handles the case of pure scalar broadcast when
    // y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
     // broadcast with batch (as y2 > 1).
     //
-    // NOTE The process is the same as the above general case except simplified
-    // for y4 == 1 and the loop over y3 is contained within the
+    // NOTE The process is the same as the above general case except
+    // simplified for y4 == 1 and the loop over y3 is contained within the
     // AddScalarBroadcast function.
     for (int i0 = 0; i0 < y0; ++i0) {
       const int8* input2_data_ptr = nullptr;
       for (int i1 = 0; i1 < y1; ++i1) {
         input2_data_ptr = input2_data_reset;
         for (int i2 = 0; i2 < y2; ++i2) {
-          MaximumScalarBroadcast(y3, params, *input1_data_ptr, input2_data_ptr,
-                                 output_data_ptr);
+          scalar_broadcast_f(y3, params, *input1_data_ptr, input2_data_ptr,
+                             output_data_ptr);
           input2_data_ptr += y3;
           output_data_ptr += y3;
           input1_data_ptr += 1;
@@ -8058,7 +8102,6 @@
   }
 }
 
-// TODO(b/156140316): Try to unify the broadcast dispatch logic for binary ops.
 template <typename Op>
 inline void BroadcastMaximumDispatch(const ArithmeticParams& params,
                                      const RuntimeShape& input1_shape,
                                      const int8* input1_data,
                                      const RuntimeShape& input2_shape,
                                      const int8* input2_data,
                                      const RuntimeShape& output_shape,
                                      int8* output_data, Op op) {
@@ -8073,8 +8116,30 @@ inline void BroadcastMaximumDispatch(const ArithmeticParams& params,
         output_data, op);
   }
 
-  BroadcastMaximumFivefold(params, input1_shape, input1_data, input2_shape,
-                           input2_data, output_shape, output_data);
+  BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape,
+                          input2_data, output_shape, output_data,
+                          MaximumElementwise, MaximumScalarBroadcast,
+                          "BroadcastMaximumFivefoldInt8/8bit");
+}
+
+template <typename Op>
+inline void BroadcastMinimumDispatch(const ArithmeticParams& params,
+                                     const RuntimeShape& input1_shape,
+                                     const int8* input1_data,
+                                     const RuntimeShape& input2_shape,
+                                     const int8* input2_data,
+                                     const RuntimeShape& output_shape,
+                                     int8* output_data, Op op) {
+  if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
+    return reference_ops::MaximumMinimumBroadcastSlow(
+        input1_shape, input1_data, input2_shape, input2_data, output_shape,
+        output_data, op);
+  }
+
+  BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape,
+                          input2_data, output_shape, output_data,
+                          MinimumElementwise, MinimumScalarBroadcast,
+                          "BroadcastMinimumFivefoldInt8/8bit");
 }
 
 }  // namespace optimized_ops
diff --git a/tensorflow/lite/kernels/maximum_minimum.cc b/tensorflow/lite/kernels/maximum_minimum.cc
index abe9647f69e..cad86acd8dd 100644
--- a/tensorflow/lite/kernels/maximum_minimum.cc
+++ b/tensorflow/lite/kernels/maximum_minimum.cc
@@ -125,6 +125,31 @@ void TFLiteOperation(
       MaximumOp::template op<int8>);
 }
 
+// Minimum generic opt int8.
+template <>
+void TFLiteOperation<maximum_minimum::kGenericOptimized, int8, MinimumOp>(
+    TfLiteContext* context, TfLiteNode* node, const OpContext& op_context) {
+  tflite::ArithmeticParams op_params;
+  const bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
+      GetTensorShape(op_context.input1), GetTensorShape(op_context.input2),
+      &op_params);
+  if (need_broadcast) {
+    optimized_ops::BroadcastMinimumDispatch(
+        op_params, GetTensorShape(op_context.input1),
+        GetTensorData<int8>(op_context.input1),
+        GetTensorShape(op_context.input2),
+        GetTensorData<int8>(op_context.input2),
+        GetTensorShape(op_context.output),
+        GetTensorData<int8>(op_context.output), MinimumOp::template op<int8>);
+    return;
+  }
+  reference_ops::MaximumMinimumBroadcastSlow(
+      GetTensorShape(op_context.input1), GetTensorData<int8>(op_context.input1),
+      GetTensorShape(op_context.input2), GetTensorData<int8>(op_context.input2),
+      GetTensorShape(op_context.output), GetTensorData<int8>(op_context.output),
+      MinimumOp::template op<int8>);
+}
+
 template <KernelType kernel_type, typename OpType>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   OpContext op_context(context, node);
@@ -186,10 +211,21 @@ TfLiteRegistration* Register_MINIMUM_REF() {
                             maximum_minimum::MinimumOp>};
   return &r;
 }
+
+TfLiteRegistration* Register_MINIMUM_GENERIC_OPT() {
+  static TfLiteRegistration r = {
+      nullptr, nullptr, maximum_minimum::Prepare,
+      maximum_minimum::Eval<maximum_minimum::kGenericOptimized,
+                            maximum_minimum::MinimumOp>};
+  return &r;
+}
+
 TfLiteRegistration* Register_MAXIMUM() {
   return Register_MAXIMUM_GENERIC_OPT();
 }
-TfLiteRegistration* Register_MINIMUM() { return Register_MINIMUM_REF(); }
+TfLiteRegistration* Register_MINIMUM() {
+  return Register_MINIMUM_GENERIC_OPT();
+}
 
 }  // namespace builtin
 }  // namespace ops
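
Note for reviewers (not part of the patch): the "same scale & zero point" assumption above MinimumElementwise is what makes it valid to take the minimum directly on the quantized int8 values, with no offset/multiplier arithmetic in the inner loops. With a shared scale s > 0 and zero point z, dequantization x = s * (q - z) is strictly increasing in q, so

    min(s * (q1 - z), s * (q2 - z)) = s * (min(q1, q2) - z)

i.e. the quantized-domain min dequantizes to exactly the real-domain min. The same argument covers max (and is why the existing Maximum kernels carry the identical comment); it breaks down as soon as the two inputs have different scales or zero points, which is why that case stays on the generic reference path.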
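The fivefold comment ("input1.shape.FlatSize = y0 * y1 * y2 * y4, input2.shape.FlatSize = y0 * y2 * y3 * y4") is easiest to see in a standalone sketch. The following is illustrative only and not part of the patch: the dimensions, fill values, and index helpers are made up, and the params/offset plumbing and input-switching logic of BinaryBroadcastFiveFold are omitted. It runs the general-case (y4 > 1) traversal with a plain std::min inner kernel standing in for MinimumElementwise, then checks every element against naive 5-D indexing, which also exercises the two FlatSize identities:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int y0 = 2, y1 = 3, y2 = 2, y3 = 4, y4 = 5;
  // input1 is logically [y0, y1, y2, 1, y4]; input2 is [y0, 1, y2, y3, y4].
  std::vector<int8_t> input1(y0 * y1 * y2 * y4), input2(y0 * y2 * y3 * y4);
  std::vector<int8_t> output(y0 * y1 * y2 * y3 * y4);
  for (int i = 0; i < static_cast<int>(input1.size()); ++i)
    input1[i] = static_cast<int8_t>(i * 7);
  for (int i = 0; i < static_cast<int>(input2.size()); ++i)
    input2[i] = static_cast<int8_t>(i * 13);

  // Fivefold traversal: input1 advances one y4-chunk per i2 (y3 is broadcast
  // for input1); input2 rewinds to input2_reset once per i1 (y1 is broadcast
  // for input2) and only commits its progress once per i0.
  const int8_t* input1_ptr = input1.data();
  const int8_t* input2_reset = input2.data();
  int8_t* out_ptr = output.data();
  for (int i0 = 0; i0 < y0; ++i0) {
    const int8_t* input2_ptr = nullptr;
    for (int i1 = 0; i1 < y1; ++i1) {
      input2_ptr = input2_reset;
      for (int i2 = 0; i2 < y2; ++i2) {
        for (int i3 = 0; i3 < y3; ++i3) {
          for (int i = 0; i < y4; ++i)  // stands in for MinimumElementwise
            out_ptr[i] = std::min(input1_ptr[i], input2_ptr[i]);
          input2_ptr += y4;
          out_ptr += y4;
        }
        input1_ptr += y4;
      }
    }
    input2_reset = input2_ptr;
  }

  // Cross-check against naive 5-D indexing of the two flattened inputs.
  auto idx1 = [&](int a, int b, int c, int e) {
    return ((a * y1 + b) * y2 + c) * y4 + e;
  };
  auto idx2 = [&](int a, int c, int d, int e) {
    return ((a * y2 + c) * y3 + d) * y4 + e;
  };
  auto idxo = [&](int a, int b, int c, int d, int e) {
    return (((a * y1 + b) * y2 + c) * y3 + d) * y4 + e;
  };
  for (int a = 0; a < y0; ++a)
    for (int b = 0; b < y1; ++b)
      for (int c = 0; c < y2; ++c)
        for (int d = 0; d < y3; ++d)
          for (int e = 0; e < y4; ++e)
            assert(output[idxo(a, b, c, d, e)] ==
                   std::min(input1[idx1(a, b, c, e)],
                            input2[idx2(a, c, d, e)]));
  return 0;
}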
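The y4 == 1 special case collapses the two innermost loops: each input1 value becomes a scalar broadcast against y3 contiguous input2 values, which is why that branch calls MinimumScalarBroadcast (with its vdupq_n_s8 splat) instead of MinimumElementwise. A minimal NEON-free sketch of that control flow, again with hypothetical names and not part of the patch:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Stands in for MinimumScalarBroadcast: one scalar vs. y3 contiguous values.
void ScalarBroadcastMin(int size, int8_t scalar, const int8_t* input2,
                        int8_t* output) {
  for (int i = 0; i < size; ++i) output[i] = std::min(scalar, input2[i]);
}

// The y4 == 1 branch of BinaryBroadcastFiveFold, minus params and profiling.
// input1 holds y0 * y1 * y2 scalars; input2 holds y0 * y2 * y3 values.
void FiveFoldY4Equals1(int y0, int y1, int y2, int y3, const int8_t* input1,
                       const int8_t* input2, int8_t* output) {
  const int8_t* input2_reset = input2;
  for (int i0 = 0; i0 < y0; ++i0) {
    const int8_t* input2_ptr = nullptr;
    for (int i1 = 0; i1 < y1; ++i1) {
      input2_ptr = input2_reset;
      for (int i2 = 0; i2 < y2; ++i2) {
        ScalarBroadcastMin(y3, *input1, input2_ptr, output);
        input2_ptr += y3;
        output += y3;
        input1 += 1;  // one input1 scalar consumed per (i1, i2) step
      }
    }
    input2_reset = input2_ptr;  // commit the y2 * y3 block broadcast over y1
  }
}

int main() {
  // Pure scalar broadcast (y0 == y1 == y2 == 1) reduces to a single call.
  const int8_t scalar[] = {5};
  const int8_t data[] = {3, 9, -7, 5};
  int8_t out[4];
  FiveFoldY4Equals1(1, 1, 1, 4, scalar, data, out);
  for (int8_t v : out) std::printf("%d ", v);  // prints: 3 5 -7 5
  std::printf("\n");
  return 0;
}

As the comment in the patch notes, the same branch also covers scalar-broadcast-with-batch shapes (y2 > 1) with low overhead, since only the pointer bookkeeping around the kernel call changes.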