Optimize Quantizied broadcast add op.
PiperOrigin-RevId: 220782252
This commit is contained in:
parent
a4097ee215
commit
5c07cebb4c
@ -220,30 +220,38 @@ TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
const TfLiteTensor* input2,
|
||||
TfLiteTensor* output) {
|
||||
if (output->type == kTfLiteUInt8) {
|
||||
tflite::ArithmeticParams op_params;
|
||||
op_params.left_shift = data->left_shift;
|
||||
op_params.input1_offset = data->input1_offset;
|
||||
op_params.input1_multiplier = data->input1_multiplier;
|
||||
op_params.input1_shift = data->input1_shift;
|
||||
op_params.input2_offset = data->input2_offset;
|
||||
op_params.input2_multiplier = data->input2_multiplier;
|
||||
op_params.input2_shift = data->input2_shift;
|
||||
op_params.output_offset = data->output_offset;
|
||||
op_params.output_multiplier = data->output_multiplier;
|
||||
op_params.output_shift = data->output_shift;
|
||||
SetActivationParams(data->output_activation_min,
|
||||
data->output_activation_max, &op_params);
|
||||
bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
|
||||
GetTensorShape(input1), GetTensorShape(input2), &op_params);
|
||||
#define TF_LITE_ADD(type, opname) \
|
||||
tflite::ArithmeticParams op_params; \
|
||||
op_params.left_shift = data->left_shift; \
|
||||
op_params.input1_offset = data->input1_offset; \
|
||||
op_params.input1_multiplier = data->input1_multiplier; \
|
||||
op_params.input1_shift = data->input1_shift; \
|
||||
op_params.input2_offset = data->input2_offset; \
|
||||
op_params.input2_multiplier = data->input2_multiplier; \
|
||||
op_params.input2_shift = data->input2_shift; \
|
||||
op_params.output_offset = data->output_offset; \
|
||||
op_params.output_multiplier = data->output_multiplier; \
|
||||
op_params.output_shift = data->output_shift; \
|
||||
SetActivationParams(data->output_activation_min, \
|
||||
data->output_activation_max, &op_params); \
|
||||
type::opname(op_params, GetTensorShape(input1), \
|
||||
GetTensorData<uint8_t>(input1), GetTensorShape(input2), \
|
||||
GetTensorData<uint8_t>(input2), GetTensorShape(output), \
|
||||
GetTensorData<uint8_t>(output))
|
||||
// The quantized version of Add doesn't support activations, so we
|
||||
// always use BroadcastAdd.
|
||||
GetTensorData<uint8_t>(output));
|
||||
if (kernel_type == kReference) {
|
||||
TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow);
|
||||
if (need_broadcast) {
|
||||
TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow);
|
||||
} else {
|
||||
TF_LITE_ADD(reference_ops, Add);
|
||||
}
|
||||
} else {
|
||||
TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow);
|
||||
if (need_broadcast) {
|
||||
TF_LITE_ADD(optimized_ops, BroadcastAddFivefold);
|
||||
} else {
|
||||
TF_LITE_ADD(optimized_ops, Add);
|
||||
}
|
||||
}
|
||||
#undef TF_LITE_ADD
|
||||
} else if (output->type == kTfLiteInt16) {
|
||||
|
@ -70,6 +70,7 @@ using reference_ops::LessEqual;
|
||||
using reference_ops::LessEqualWithScaling;
|
||||
using reference_ops::LessWithScaling;
|
||||
using reference_ops::Mean;
|
||||
using reference_ops::ProcessBroadcastShapes;
|
||||
using reference_ops::RankOneSelect;
|
||||
using reference_ops::Relu1;
|
||||
using reference_ops::Relu6;
|
||||
|
@ -100,6 +100,98 @@ gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
|
||||
|
||||
namespace reference_ops {
|
||||
|
||||
// Return true for broadcast case, false otherwise.
|
||||
inline bool ProcessBroadcastShapes(const RuntimeShape& shape0,
|
||||
const RuntimeShape& shape1,
|
||||
tflite::ArithmeticParams* params) {
|
||||
const int dims_count =
|
||||
std::max(shape0.DimensionsCount(), shape1.DimensionsCount());
|
||||
|
||||
params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
|
||||
RuntimeShape scalar_shape(dims_count, 1);
|
||||
|
||||
auto extended_shape0 = RuntimeShape::ExtendedShape(dims_count, shape0);
|
||||
auto extended_shape1 = RuntimeShape::ExtendedShape(dims_count, shape1);
|
||||
|
||||
// Check for "exact" match, implicitly accepting any scalar shapes.
|
||||
if (extended_shape0 == extended_shape1) {
|
||||
params->broadcast_category = BroadcastableOpCategory::kNonBroadcast;
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = dims_count - 1; i >= 0; --i) {
|
||||
if (extended_shape0.Dims(i) == extended_shape1.Dims(i)) {
|
||||
continue;
|
||||
} else if (extended_shape0.Dims(i) == 1) {
|
||||
params->broadcast_category =
|
||||
BroadcastableOpCategory::kFirstInputBroadcastsFast;
|
||||
break;
|
||||
} else if (extended_shape1.Dims(i) == 1) {
|
||||
params->broadcast_category =
|
||||
BroadcastableOpCategory::kSecondInputBroadcastsFast;
|
||||
break;
|
||||
} else {
|
||||
params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (params->broadcast_category !=
|
||||
BroadcastableOpCategory::kFirstInputBroadcastsFast &&
|
||||
params->broadcast_category !=
|
||||
BroadcastableOpCategory::kSecondInputBroadcastsFast) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// From this point it is assumed contractually that corresponding dimensions
|
||||
// in shape0 and shape1 are either (a) equal or (b) one or other equals 1.
|
||||
const bool swap_inputs = params->broadcast_category ==
|
||||
BroadcastableOpCategory::kSecondInputBroadcastsFast;
|
||||
const RuntimeShape* shape_a =
|
||||
swap_inputs ? &extended_shape1 : &extended_shape0;
|
||||
const RuntimeShape* shape_b =
|
||||
swap_inputs ? &extended_shape0 : &extended_shape1;
|
||||
|
||||
int i = dims_count - 1;
|
||||
params->broadcast_shape[0] = 1;
|
||||
params->broadcast_shape[1] = 1;
|
||||
params->broadcast_shape[2] = 1;
|
||||
params->broadcast_shape[3] = 1;
|
||||
params->broadcast_shape[4] = 1;
|
||||
// y_0 is greedy: include dims if both or neither equal 1: in other words,
|
||||
// test for equality rather than (shape_a->Dims(i) != 1).
|
||||
while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
|
||||
params->broadcast_shape[4] *= shape_b->Dims(i);
|
||||
--i;
|
||||
}
|
||||
// Here either input_a or input_b has dim of 1 (if i >= 0). If it is input_b
|
||||
// that has the unit dimension, the next two loops are not entered.
|
||||
while (i >= 0 && shape_a->Dims(i) == 1) {
|
||||
params->broadcast_shape[3] *= shape_b->Dims(i);
|
||||
--i;
|
||||
}
|
||||
while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
|
||||
params->broadcast_shape[2] *= shape_a->Dims(i);
|
||||
--i;
|
||||
}
|
||||
// Here either input_a or input_b has dim of 1 (if i >= 0).
|
||||
while (i >= 0 && shape_b->Dims(i) == 1) {
|
||||
params->broadcast_shape[1] *= shape_a->Dims(i);
|
||||
--i;
|
||||
}
|
||||
while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
|
||||
params->broadcast_shape[0] *= shape_b->Dims(i);
|
||||
--i;
|
||||
}
|
||||
|
||||
// Rarer case is when the broadcast dimensions cannot be handled by a fivefold
|
||||
// loop.
|
||||
if (i >= 0) {
|
||||
params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int CountLeadingZeros(T integer_input) {
|
||||
static_assert(std::is_unsigned<T>::value,
|
||||
|
Loading…
Reference in New Issue
Block a user