From e9edfeff0bd20da44e953b22302cf7d5ad99c46d Mon Sep 17 00:00:00 2001 From: Advait Jain Date: Thu, 30 Apr 2020 13:54:11 -0700 Subject: [PATCH] Use persistent buffer in reference svdf. PiperOrigin-RevId: 309293414 Change-Id: I33e9b80fc7f634968a7a11ef68a34550c367b92d --- tensorflow/lite/micro/kernels/svdf.cc | 136 +++++++++++--------------- 1 file changed, 57 insertions(+), 79 deletions(-) diff --git a/tensorflow/lite/micro/kernels/svdf.cc b/tensorflow/lite/micro/kernels/svdf.cc index e2cacf17927..d1f9746679c 100644 --- a/tensorflow/lite/micro/kernels/svdf.cc +++ b/tensorflow/lite/micro/kernels/svdf.cc @@ -331,8 +331,20 @@ constexpr int kInputActivationStateTensor = 4; // Output tensor. constexpr int kOutputTensor = 0; +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + void* data = nullptr; + if (context->AllocatePersistentBuffer(context, sizeof(OpData), &data) == + kTfLiteError) { + return nullptr; + } + return data; +} + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - const auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data); + TFLITE_DCHECK(node->builtin_data != nullptr); + + const auto* params = static_cast<const TfLiteSVDFParams*>(node->builtin_data); // Validate Tensor Inputs (dtype depends on quantization): // [0] = Input, {2, batch_size, input_size} @@ -341,7 +353,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // [3] = Bias (optional), {1, num_units} // [4] = Activation State (variable), // {2, batch_size, memory_size * num_filters} - const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* weights_feature = GetInput(context, node, kWeightsFeatureTensor); @@ -360,8 +371,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int num_units = num_filters / rank; const int memory_size = weights_time->dims->data[1]; - const bool is_full_integer = input->type == kTfLiteInt8; - // Validate Input Tensor: 
TF_LITE_ENSURE(context, input->type == kTfLiteFloat32 || input->type == kTfLiteInt8); @@ -385,7 +394,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, weights_time->dims->data[1], memory_size); // Validate Optional Bias Input Tensor: - if (bias) { + if (bias != nullptr) { TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units); } @@ -395,50 +404,52 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1], memory_size * num_filters); - if (is_full_integer) { - TF_LITE_ENSURE_EQ(context, node->inputs->size, 5); + TF_LITE_ENSURE_EQ(context, node->inputs->size, 5); + if (input->type == kTfLiteInt8) { TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8); TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteInt16); - - if (bias) { + TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16); + if (bias != nullptr) { TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32); } - TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16); - - // Validate Scratch Tensors: - // [0] = (shared - see float block below for usage) - // [1] = Output Temp, int8_t, {2, num_units, batch_size} - // TODO(b/132070898): Scratch values are used as stack variables in - // EvalIntegerSVDF(). 
- - // Validate output tensor: TF_LITE_ENSURE_EQ(context, output->type, kTfLiteInt8); - } else { - TF_LITE_ENSURE_EQ(context, node->inputs->size, 5); - // Validate Input Tensor dtypes: + const auto* input_params = + reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params); + const auto* weights_feature_params = + static_cast<const TfLiteAffineQuantization*>( + weights_feature->quantization.params); + const auto* state_params = static_cast<const TfLiteAffineQuantization*>( + activation_state->quantization.params); + const auto* weight_time_params = + static_cast<const TfLiteAffineQuantization*>( + weights_time->quantization.params); + const auto* output_params = static_cast<const TfLiteAffineQuantization*>( + output->quantization.params); + const double effective_scale_1 = static_cast<double>( + input_params->scale->data[0] * weights_feature_params->scale->data[0] / + state_params->scale->data[0]); + const double effective_scale_2 = static_cast<double>( + state_params->scale->data[0] * weight_time_params->scale->data[0] / + output_params->scale->data[0]); + + TFLITE_DCHECK(node->user_data != nullptr); + OpData* data = static_cast<OpData*>(node->user_data); + + QuantizeMultiplier(effective_scale_1, &(data->effective_scale_1_a), + &(data->effective_scale_1_b)); + QuantizeMultiplier(effective_scale_2, &(data->effective_scale_2_a), + &(data->effective_scale_2_b)); + + } else { TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteFloat32); TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteFloat32); TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32); - - if (bias) { + if (bias != nullptr) { TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32); } - - // Validate shared Scratch Tensor: - // [0] = Holds dot-product of time-forward calculations in - // ApplyTimeWeightsBiasAndActivation(): - // float/int32, {2, batch_size, num_filters} - // TODO(b/132070898): Scratch values are used as stack variables in - // EvalIntegerSVDF(). - - // Full-float SVDF only uses the one shared scratch tensor (see above for - // usage). 
- // TODO(b/132070898): Use input tensor as variable until scratch tensor - // allocation has been implemented. - // TF_LITE_ENSURE_EQ(context, node->temporaries->size, 1); TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); } @@ -458,13 +469,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetVariableInput(context, node, kInputActivationStateTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - const bool is_full_integer = input->type == kTfLiteInt8; - switch (weights_feature->type) { case kTfLiteFloat32: { - // TODO(b/132070898): Use input tensor as variable until scratch tensor - // allocation has been implemented. - // TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0); EvalFloatSVDF(context, node, input, weights_feature, weights_time, bias, params, activation_state, output); return kTfLiteOk; @@ -472,43 +478,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } case kTfLiteInt8: { - if (is_full_integer) { - // TODO(b/132070898): Store these values in ::Prepare() instead of - // ::Eval(): - // Calculate effective scales. 
- OpData op_data; - auto* input_params = reinterpret_cast<TfLiteAffineQuantization*>( - input->quantization.params); - auto* weights_feature_params = - reinterpret_cast<TfLiteAffineQuantization*>( - weights_feature->quantization.params); - auto* state_params = reinterpret_cast<TfLiteAffineQuantization*>( - activation_state->quantization.params); - auto* weight_time_params = reinterpret_cast<TfLiteAffineQuantization*>( - weights_time->quantization.params); - auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>( - output->quantization.params); - const double effective_scale_1 = - static_cast<double>(input_params->scale->data[0] * - weights_feature_params->scale->data[0] / - state_params->scale->data[0]); - const double effective_scale_2 = static_cast<double>( - state_params->scale->data[0] * weight_time_params->scale->data[0] / - output_params->scale->data[0]); - QuantizeMultiplier(effective_scale_1, &op_data.effective_scale_1_a, - &op_data.effective_scale_1_b); - QuantizeMultiplier(effective_scale_2, &op_data.effective_scale_2_a, - &op_data.effective_scale_2_b); - - TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActRelu); - EvalIntegerSVDF( - context, node, input, weights_feature, weights_time, bias, params, - activation_state, output, op_data.effective_scale_1_a, - op_data.effective_scale_1_b, op_data.effective_scale_2_a, - op_data.effective_scale_2_b, input->params.zero_point, - output->params.zero_point); - return kTfLiteOk; - } + TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActRelu); + TFLITE_DCHECK(node->user_data != nullptr); + const OpData& data = *(static_cast<const OpData*>(node->user_data)); + EvalIntegerSVDF(context, node, input, weights_feature, weights_time, bias, + params, activation_state, output, + data.effective_scale_1_a, data.effective_scale_1_b, + data.effective_scale_2_a, data.effective_scale_2_b, + input->params.zero_point, output->params.zero_point); + return kTfLiteOk; break; } @@ -523,7 +501,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace svdf TfLiteRegistration* Register_SVDF() { - static 
TfLiteRegistration r = {/*init=*/svdf::Init, /*free=*/nullptr, /*prepare=*/svdf::Prepare, /*invoke=*/svdf::Eval,