Optimize quantized broadcast add op.

PiperOrigin-RevId: 220782252
A. Unique TensorFlower, 2018-11-09 04:08:36 -08:00 (committed by TensorFlower Gardener)
parent a4097ee215
commit 5c07cebb4c
3 changed files with 119 additions and 18 deletions
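
Summary of the change, as read from the diffs below: previously the quantized uint8 Add kernel always went through BroadcastAdd4DSlow, whether or not the inputs actually required broadcasting. This commit hoists the quantization-parameter setup out of the TF_LITE_ADD macro, adds a ProcessBroadcastShapes helper that classifies the two input shapes, and dispatches to the plain elementwise Add when the shapes match, or to BroadcastAddFivefold on the optimized path when they do not.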

kernels/add.cc

@@ -220,30 +220,38 @@ TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
                               const TfLiteTensor* input2,
                               TfLiteTensor* output) {
   if (output->type == kTfLiteUInt8) {
+    tflite::ArithmeticParams op_params;
+    op_params.left_shift = data->left_shift;
+    op_params.input1_offset = data->input1_offset;
+    op_params.input1_multiplier = data->input1_multiplier;
+    op_params.input1_shift = data->input1_shift;
+    op_params.input2_offset = data->input2_offset;
+    op_params.input2_multiplier = data->input2_multiplier;
+    op_params.input2_shift = data->input2_shift;
+    op_params.output_offset = data->output_offset;
+    op_params.output_multiplier = data->output_multiplier;
+    op_params.output_shift = data->output_shift;
+    SetActivationParams(data->output_activation_min,
+                        data->output_activation_max, &op_params);
+    bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
+        GetTensorShape(input1), GetTensorShape(input2), &op_params);
 #define TF_LITE_ADD(type, opname)                                      \
-  tflite::ArithmeticParams op_params;                                  \
-  op_params.left_shift = data->left_shift;                             \
-  op_params.input1_offset = data->input1_offset;                       \
-  op_params.input1_multiplier = data->input1_multiplier;               \
-  op_params.input1_shift = data->input1_shift;                         \
-  op_params.input2_offset = data->input2_offset;                       \
-  op_params.input2_multiplier = data->input2_multiplier;               \
-  op_params.input2_shift = data->input2_shift;                         \
-  op_params.output_offset = data->output_offset;                       \
-  op_params.output_multiplier = data->output_multiplier;               \
-  op_params.output_shift = data->output_shift;                         \
-  SetActivationParams(data->output_activation_min,                     \
-                      data->output_activation_max, &op_params);        \
   type::opname(op_params, GetTensorShape(input1),                      \
                GetTensorData<uint8_t>(input1), GetTensorShape(input2), \
                GetTensorData<uint8_t>(input2), GetTensorShape(output), \
-               GetTensorData<uint8_t>(output))
-    // The quantized version of Add doesn't support activations, so we
-    // always use BroadcastAdd.
+               GetTensorData<uint8_t>(output));
     if (kernel_type == kReference) {
-      TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow);
+      if (need_broadcast) {
+        TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow);
+      } else {
+        TF_LITE_ADD(reference_ops, Add);
+      }
     } else {
-      TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow);
+      if (need_broadcast) {
+        TF_LITE_ADD(optimized_ops, BroadcastAddFivefold);
+      } else {
+        TF_LITE_ADD(optimized_ops, Add);
+      }
     }
 #undef TF_LITE_ADD
   } else if (output->type == kTfLiteInt16) {
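
To make the control flow concrete, here is roughly what the uint8 branch above does on the optimized (kernel_type != kReference) path once TF_LITE_ADD is expanded. This is a sketch for illustration only, not literal commit code; `data` (the node's OpData pointer), the tensor pointers, and the helper functions are assumed to be in scope in add.cc exactly as in the diff:

// Illustrative expansion of the new uint8 dispatch (not literal commit code).
tflite::ArithmeticParams op_params;
op_params.left_shift = data->left_shift;
// ... the remaining offset/multiplier/shift fields are copied from OpData
// exactly as in the diff above ...
SetActivationParams(data->output_activation_min,
                    data->output_activation_max, &op_params);
bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
    GetTensorShape(input1), GetTensorShape(input2), &op_params);
if (need_broadcast) {
  // Shapes differ: use the fivefold broadcast kernel driven by
  // op_params.broadcast_shape (see ProcessBroadcastShapes below).
  optimized_ops::BroadcastAddFivefold(
      op_params, GetTensorShape(input1), GetTensorData<uint8_t>(input1),
      GetTensorShape(input2), GetTensorData<uint8_t>(input2),
      GetTensorShape(output), GetTensorData<uint8_t>(output));
} else {
  // Same shapes: plain elementwise add, no broadcast indexing overhead.
  optimized_ops::Add(op_params, GetTensorShape(input1),
                     GetTensorData<uint8_t>(input1), GetTensorShape(input2),
                     GetTensorData<uint8_t>(input2), GetTensorShape(output),
                     GetTensorData<uint8_t>(output));
}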

kernels/internal/optimized/optimized_ops.h

@@ -70,6 +70,7 @@ using reference_ops::LessEqual;
 using reference_ops::LessEqualWithScaling;
 using reference_ops::LessWithScaling;
 using reference_ops::Mean;
+using reference_ops::ProcessBroadcastShapes;
 using reference_ops::RankOneSelect;
 using reference_ops::Relu1;
 using reference_ops::Relu6;
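
Note that optimized_ops re-exports the reference implementation here, so the optimized_ops::ProcessBroadcastShapes call in add.cc above resolves to the single definition added to reference_ops.h below.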

kernels/internal/reference/reference_ops.h

@@ -100,6 +100,98 @@ gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
 namespace reference_ops {
+
+// Return true for broadcast case, false otherwise.
+inline bool ProcessBroadcastShapes(const RuntimeShape& shape0,
+                                   const RuntimeShape& shape1,
+                                   tflite::ArithmeticParams* params) {
+  const int dims_count =
+      std::max(shape0.DimensionsCount(), shape1.DimensionsCount());
+
+  params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
+  RuntimeShape scalar_shape(dims_count, 1);
+
+  auto extended_shape0 = RuntimeShape::ExtendedShape(dims_count, shape0);
+  auto extended_shape1 = RuntimeShape::ExtendedShape(dims_count, shape1);
+
+  // Check for "exact" match, implicitly accepting any scalar shapes.
+  if (extended_shape0 == extended_shape1) {
+    params->broadcast_category = BroadcastableOpCategory::kNonBroadcast;
+    return false;
+  }
+
+  for (int i = dims_count - 1; i >= 0; --i) {
+    if (extended_shape0.Dims(i) == extended_shape1.Dims(i)) {
+      continue;
+    } else if (extended_shape0.Dims(i) == 1) {
+      params->broadcast_category =
+          BroadcastableOpCategory::kFirstInputBroadcastsFast;
+      break;
+    } else if (extended_shape1.Dims(i) == 1) {
+      params->broadcast_category =
+          BroadcastableOpCategory::kSecondInputBroadcastsFast;
+      break;
+    } else {
+      params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
+      break;
+    }
+  }
+
+  if (params->broadcast_category !=
+          BroadcastableOpCategory::kFirstInputBroadcastsFast &&
+      params->broadcast_category !=
+          BroadcastableOpCategory::kSecondInputBroadcastsFast) {
+    return false;
+  }
+
+  // From this point it is assumed contractually that corresponding dimensions
+  // in shape0 and shape1 are either (a) equal or (b) one or other equals 1.
+  const bool swap_inputs = params->broadcast_category ==
+                           BroadcastableOpCategory::kSecondInputBroadcastsFast;
+  const RuntimeShape* shape_a =
+      swap_inputs ? &extended_shape1 : &extended_shape0;
+  const RuntimeShape* shape_b =
+      swap_inputs ? &extended_shape0 : &extended_shape1;
+
+  int i = dims_count - 1;
+  params->broadcast_shape[0] = 1;
+  params->broadcast_shape[1] = 1;
+  params->broadcast_shape[2] = 1;
+  params->broadcast_shape[3] = 1;
+  params->broadcast_shape[4] = 1;
+  // y_0 is greedy: include dims if both or neither equal 1: in other words,
+  // test for equality rather than (shape_a->Dims(i) != 1).
+  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
+    params->broadcast_shape[4] *= shape_b->Dims(i);
+    --i;
+  }
+  // Here either input_a or input_b has dim of 1 (if i >= 0). If it is input_b
+  // that has the unit dimension, the next two loops are not entered.
+  while (i >= 0 && shape_a->Dims(i) == 1) {
+    params->broadcast_shape[3] *= shape_b->Dims(i);
+    --i;
+  }
+  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
+    params->broadcast_shape[2] *= shape_a->Dims(i);
+    --i;
+  }
+  // Here either input_a or input_b has dim of 1 (if i >= 0).
+  while (i >= 0 && shape_b->Dims(i) == 1) {
+    params->broadcast_shape[1] *= shape_a->Dims(i);
+    --i;
+  }
+  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
+    params->broadcast_shape[0] *= shape_b->Dims(i);
+    --i;
+  }
+
+  // Rarer case is when the broadcast dimensions cannot be handled by a
+  // fivefold loop.
+  if (i >= 0) {
+    params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
+  }
+  return true;
+}
+
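// Worked example (annotation added for this writeup, not part of the commit):
// for shape0 = [2, 3, 1, 5] and shape1 = [2, 3, 4, 5], the scan above selects
// kFirstInputBroadcastsFast (shape0 holds the unit dimension) and the five
// loops yield broadcast_shape = {1, 1, 6, 4, 5}: a trailing run of 5 elements
// common to both inputs, a broadcast factor of 4, and the coalesced leading
// dimensions 2*3 = 6. BroadcastAddFivefold iterates these five counters
// instead of taking the generic four-dimensional slow path.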
 template <typename T>
 int CountLeadingZeros(T integer_input) {
   static_assert(std::is_unsigned<T>::value,