diff --git a/tensorflow/lite/kernels/add.cc b/tensorflow/lite/kernels/add.cc index f4bfd8d3248..32a7c100ce5 100644 --- a/tensorflow/lite/kernels/add.cc +++ b/tensorflow/lite/kernels/add.cc @@ -220,30 +220,38 @@ TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input2, TfLiteTensor* output) { if (output->type == kTfLiteUInt8) { + tflite::ArithmeticParams op_params; + op_params.left_shift = data->left_shift; + op_params.input1_offset = data->input1_offset; + op_params.input1_multiplier = data->input1_multiplier; + op_params.input1_shift = data->input1_shift; + op_params.input2_offset = data->input2_offset; + op_params.input2_multiplier = data->input2_multiplier; + op_params.input2_shift = data->input2_shift; + op_params.output_offset = data->output_offset; + op_params.output_multiplier = data->output_multiplier; + op_params.output_shift = data->output_shift; + SetActivationParams(data->output_activation_min, + data->output_activation_max, &op_params); + bool need_broadcast = optimized_ops::ProcessBroadcastShapes( + GetTensorShape(input1), GetTensorShape(input2), &op_params); #define TF_LITE_ADD(type, opname) \ - tflite::ArithmeticParams op_params; \ - op_params.left_shift = data->left_shift; \ - op_params.input1_offset = data->input1_offset; \ - op_params.input1_multiplier = data->input1_multiplier; \ - op_params.input1_shift = data->input1_shift; \ - op_params.input2_offset = data->input2_offset; \ - op_params.input2_multiplier = data->input2_multiplier; \ - op_params.input2_shift = data->input2_shift; \ - op_params.output_offset = data->output_offset; \ - op_params.output_multiplier = data->output_multiplier; \ - op_params.output_shift = data->output_shift; \ - SetActivationParams(data->output_activation_min, \ - data->output_activation_max, &op_params); \ type::opname(op_params, GetTensorShape(input1), \ GetTensorData(input1), GetTensorShape(input2), \ GetTensorData(input2), GetTensorShape(output), \ - GetTensorData(output)) - // The quantized version of Add doesn't support activations, so we - // always use BroadcastAdd. + GetTensorData(output)); if (kernel_type == kReference) { - TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow); + if (need_broadcast) { + TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow); + } else { + TF_LITE_ADD(reference_ops, Add); + } } else { - TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow); + if (need_broadcast) { + TF_LITE_ADD(optimized_ops, BroadcastAddFivefold); + } else { + TF_LITE_ADD(optimized_ops, Add); + } } #undef TF_LITE_ADD } else if (output->type == kTfLiteInt16) { diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index 00be5a9db83..6f2cd4faab2 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -70,6 +70,7 @@ using reference_ops::LessEqual; using reference_ops::LessEqualWithScaling; using reference_ops::LessWithScaling; using reference_ops::Mean; +using reference_ops::ProcessBroadcastShapes; using reference_ops::RankOneSelect; using reference_ops::Relu1; using reference_ops::Relu6; diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index 6d22b2c017a..1bd9129488a 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -100,6 +100,98 @@ gemmlowp::FixedPoint SaturatingSub( namespace reference_ops { +// Return true for broadcast case, false otherwise. +inline bool ProcessBroadcastShapes(const RuntimeShape& shape0, + const RuntimeShape& shape1, + tflite::ArithmeticParams* params) { + const int dims_count = + std::max(shape0.DimensionsCount(), shape1.DimensionsCount()); + + params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast; + RuntimeShape scalar_shape(dims_count, 1); + + auto extended_shape0 = RuntimeShape::ExtendedShape(dims_count, shape0); + auto extended_shape1 = RuntimeShape::ExtendedShape(dims_count, shape1); + + // Check for "exact" match, implicitly accepting any scalar shapes. + if (extended_shape0 == extended_shape1) { + params->broadcast_category = BroadcastableOpCategory::kNonBroadcast; + return false; + } + + for (int i = dims_count - 1; i >= 0; --i) { + if (extended_shape0.Dims(i) == extended_shape1.Dims(i)) { + continue; + } else if (extended_shape0.Dims(i) == 1) { + params->broadcast_category = + BroadcastableOpCategory::kFirstInputBroadcastsFast; + break; + } else if (extended_shape1.Dims(i) == 1) { + params->broadcast_category = + BroadcastableOpCategory::kSecondInputBroadcastsFast; + break; + } else { + params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast; + break; + } + } + + if (params->broadcast_category != + BroadcastableOpCategory::kFirstInputBroadcastsFast && + params->broadcast_category != + BroadcastableOpCategory::kSecondInputBroadcastsFast) { + return false; + } + + // From this point it is assumed contractually that corresponding dimensions + // in shape0 and shape1 are either (a) equal or (b) one or other equals 1. + const bool swap_inputs = params->broadcast_category == + BroadcastableOpCategory::kSecondInputBroadcastsFast; + const RuntimeShape* shape_a = + swap_inputs ? &extended_shape1 : &extended_shape0; + const RuntimeShape* shape_b = + swap_inputs ? &extended_shape0 : &extended_shape1; + + int i = dims_count - 1; + params->broadcast_shape[0] = 1; + params->broadcast_shape[1] = 1; + params->broadcast_shape[2] = 1; + params->broadcast_shape[3] = 1; + params->broadcast_shape[4] = 1; + // y_0 is greedy: include dims if both or neither equal 1: in other words, + // test for equality rather than (shape_a->Dims(i) != 1). + while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) { + params->broadcast_shape[4] *= shape_b->Dims(i); + --i; + } + // Here either input_a or input_b has dim of 1 (if i >= 0). If it is input_b + // that has the unit dimension, the next two loops are not entered. + while (i >= 0 && shape_a->Dims(i) == 1) { + params->broadcast_shape[3] *= shape_b->Dims(i); + --i; + } + while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) { + params->broadcast_shape[2] *= shape_a->Dims(i); + --i; + } + // Here either input_a or input_b has dim of 1 (if i >= 0). + while (i >= 0 && shape_b->Dims(i) == 1) { + params->broadcast_shape[1] *= shape_a->Dims(i); + --i; + } + while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) { + params->broadcast_shape[0] *= shape_b->Dims(i); + --i; + } + + // Rarer case is when the broadcast dimensions cannot be handled by a fivefold + // loop. + if (i >= 0) { + params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast; + } + return true; +} + template int CountLeadingZeros(T integer_input) { static_assert(std::is_unsigned::value,