Use persistent buffer in xtensa_hifimini/fully_connected.

PiperOrigin-RevId: 309264335
Change-Id: If15693b8cd590a947497c429d476c12109885cd3
This commit is contained in:
Advait Jain 2020-04-30 11:23:18 -07:00 committed by TensorFlower Gardener
parent 7148ced951
commit f81dadf228

View File

@ -33,13 +33,12 @@ namespace micro {
namespace xtensa {
namespace hifimini {
// Int8 optimized:
inline void FullyConnected(
const FullyConnectedParams& params, const RuntimeShape& input_shape,
const int8_t* input_data, const RuntimeShape& filter_shape,
const int8_t* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape,
int8_t* output_data) {
void FullyConnected(const FullyConnectedParams& params,
const RuntimeShape& input_shape, const int8_t* input_data,
const RuntimeShape& filter_shape, const int8_t* filter_data,
const RuntimeShape& bias_shape, const int32* bias_data,
const RuntimeShape& output_shape, int8_t* output_data) {
// TODO(b/154032858): Investigate removing extra copies.
const int32 input_offset = params.input_offset;
const int32 filter_offset = params.weights_offset;
const int32 output_offset = params.output_offset;
@ -142,72 +141,69 @@ constexpr int kWeightsTensor = 1;
// Tensor indices within the node's input/output tensor arrays.
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;
// This size will work for both the hotword (5) and ambient music (2):
constexpr int kMaxOpDataSize = 7;
// Index of the next free slot in kStaticOpData; reset to 0 by Free().
static int op_data_counter = 0;
// Statically allocated pool of per-node OpData structs, one slot handed out
// per fully-connected node instance (pre-persistent-buffer workaround).
static OpData kStaticOpData[kMaxOpDataSize];
// Computes the quantized-path parameters for a fully-connected node:
// the fixed-point output multiplier/shift derived from the input, filter
// and output scales, plus the quantized activation clamp range.
//
// NOTE(review): the diff text fused the pre- and post-refactor bodies of
// this function (an unconditional `return status;` was followed by a second,
// unreachable copy of the logic, and a stale `params` parameter remained).
// This is the reconstructed post-change version: float is excluded up front
// via DCHECK instead of being silently skipped, and `activation` is passed
// directly rather than through TfLiteFullyConnectedParams.
TfLiteStatus CalculateOpData(TfLiteContext* context,
                             TfLiteFusedActivation activation,
                             TfLiteType data_type, const TfLiteTensor* input,
                             const TfLiteTensor* filter,
                             const TfLiteTensor* bias, TfLiteTensor* output,
                             OpData* data) {
  // This kernel only supports the quantized int8 path.
  TFLITE_DCHECK(data_type != kTfLiteFloat32);

  double real_multiplier = 0.0;
  TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
      context, input, filter, bias, output, &real_multiplier));
  int exponent;
  xtensa::hifimini::QuantizeMultiplier(real_multiplier,
                                       &data->output_multiplier, &exponent);
  // The hifimini kernel expects a right-shift amount, hence the negation.
  data->output_shift = -exponent;
  return CalculateActivationRangeQuantized(context, activation, output,
                                           &data->output_activation_min,
                                           &data->output_activation_max);
}
} // namespace
// Releases per-op state. Nothing is heap-allocated; simply rewind the static
// OpData slot counter so a re-created graph reuses the pool from slot 0.
void Free(TfLiteContext* context, void* buffer) {
  op_data_counter = 0;
}
// Allocates the per-node OpData from the interpreter's persistent arena.
// Returns nullptr (interpreted by the framework as init failure) when the
// allocation is rejected.
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  void* raw = nullptr;
  const auto alloc_result =
      context->AllocatePersistentBuffer(context, sizeof(OpData), &raw);
  return alloc_result == kTfLiteError ? nullptr : raw;
}
// Prepare: validates node state and fills the OpData (allocated in Init from
// the persistent arena) with the quantization parameters for Eval.
//
// NOTE(review): the diff text fused the old and new versions here — a
// dangling `auto* params =` ran straight into a DCHECK statement, and both
// the old static-slot allocation (`kStaticOpData[op_data_counter++]`) and
// the new persistent-buffer path were present. This is the reconstructed
// post-change version, which relies on Init having stashed the OpData in
// node->user_data.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  TFLITE_DCHECK(node->builtin_data != nullptr);

  OpData* data = static_cast<OpData*>(node->user_data);
  const auto* params =
      reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  return CalculateOpData(context, params->activation, input->type, input,
                         filter, bias, output, data);
}
// Runs the int8 hifimini fully-connected kernel using the precomputed
// quantization parameters in `data`.
//
// NOTE(review): the diff text interleaved the old (`OpData* data`,
// `data->...`) and new (`const OpData& data`, `data....`) versions, and the
// tail of the xtensa::hifimini::FullyConnected call was elided by a diff
// hunk marker. This is the reconstructed post-change version; the elided
// argument list follows the standard TFLM GetTensorShape/GetTensorData
// pattern — confirm against the upstream file.
TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
                               const OpData& data, const TfLiteTensor* input,
                               const TfLiteTensor* filter,
                               const TfLiteTensor* bias, TfLiteTensor* output) {
  // TODO(b/154032858): Investigate removing extra copies.
  FullyConnectedParams op_params;
  // Offsets are negated zero points, per the reference int8 FC convention.
  op_params.input_offset = -input->params.zero_point;
  op_params.weights_offset = -filter->params.zero_point;
  op_params.output_offset = output->params.zero_point;
  op_params.output_multiplier = data.output_multiplier;
  // TODO(b/138810107): Figure out whether output shift should be inverted
  op_params.output_shift = -data.output_shift;
  op_params.quantized_activation_min = data.output_activation_min;
  op_params.quantized_activation_max = data.output_activation_max;

  xtensa::hifimini::FullyConnected(
      op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
      GetTensorShape(filter), GetTensorData<int8_t>(filter),
      GetTensorShape(bias), GetTensorData<int32_t>(bias),
      GetTensorShape(output), GetTensorData<int8_t>(output));
  return kTfLiteOk;
}
// Eval entry point: fetches tensors and dispatches to the int8 kernel.
//
// NOTE(review): the diff text contained both the old switch-on-filter-type
// dispatch (with its unreachable trailing `return kTfLiteOk;`) and the new
// DCHECK-based path. This is the reconstructed post-change version: this
// port supports only int8, so the type check is a debug assertion rather
// than a runtime switch.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData& data = *(static_cast<const OpData*>(node->user_data));

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  // Already know in/out types are same; only int8 filters are supported.
  TFLITE_DCHECK(filter->type == kTfLiteInt8);
  return EvalQuantizedInt8(context, node, data, input, filter, bias, output);
}
} // namespace fully_connected
TfLiteRegistration* Register_FULLY_CONNECTED() {
static TfLiteRegistration r = {/*init=*/nullptr,
/*free=*/fully_connected::Free,
static TfLiteRegistration r = {/*init=*/fully_connected::Init,
/*free=*/nullptr,
/*prepare=*/fully_connected::Prepare,
/*invoke=*/fully_connected::Eval,
/*profiling_string=*/nullptr,