Implement unidirectional_sequence_lstm runtime by a separate branch of EvalInteger8x8_16.

PiperOrigin-RevId: 338074032
Change-Id: I39a4ed4588b554580b2aa922b22b57fe0ca9730a
This commit is contained in:
A. Unique TensorFlower 2020-10-20 09:25:16 -07:00 committed by TensorFlower Gardener
parent 9b5e180a69
commit df3ad536b9
7 changed files with 1544 additions and 182 deletions

View File

@ -2102,10 +2102,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
forget_layer_norm_coefficients, cell_layer_norm_coefficients,
output_layer_norm_coefficients, input_gate_bias, forget_gate_bias,
cell_gate_bias, output_gate_bias, projection_weights,
projection_bias, params, &op_data->integer_lstm_param,
output_state, cell_state, output, scratch0, scratch1, scratch2,
scratch3, scratch4, scratch5,
CpuBackendContext::GetFromContext(context));
projection_bias, params, /*forward_sequence=*/true,
/*time_major=*/true, &op_data->integer_lstm_param, output_state,
cell_state, output, scratch0, scratch1, scratch2, scratch3,
scratch4, scratch5, CpuBackendContext::GetFromContext(context));
} else {
TfLiteTensor* scratch0;
TF_LITE_ENSURE_OK(context,

View File

@ -1412,8 +1412,10 @@ inline void LstmStepInteger8x8_16(
TFLITE_DCHECK(input_to_input_effective_bias);
TFLITE_DCHECK(recurrent_to_input_effective_bias);
}
const bool use_projection = (projection_weight_ptr != nullptr);
if (use_projection) {
TFLITE_DCHECK(projection_effective_bias);
}
if (!use_cifg) {
// Calculate the input gate. (If not CIFG.)
CalculateLstmGateInteger8x8_16(
@ -1479,7 +1481,7 @@ inline void LstmStepInteger8x8_16(
quantized_proj_clip, output_state_ptr, context, scratch0, scratch4,
scratch5);
// Copy output state to the output. Note that unlike float or hybrid, output
// is always contigous.
// is always contiguous.
std::copy_n(output_state_ptr, n_batch * n_output, output_ptr);
}
@ -2177,7 +2179,7 @@ TfLiteStatus EvalInteger8x8_16(
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params,
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
const lstm_eval::IntegerLstmParameter* integer_lstm_param,
TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output,
TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
@ -2190,8 +2192,8 @@ TfLiteStatus EvalInteger8x8_16(
max_time = 1;
n_batch = input->dims->data[0];
} else {
max_time = input->dims->data[0];
n_batch = input->dims->data[1];
max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
}
// n_cell and n_output will be the same size when there is no projection.
@ -2204,13 +2206,15 @@ TfLiteStatus EvalInteger8x8_16(
// Get params for time/batch/sequence.
const int output_batch_leading_dim =
output->dims->data[output->dims->size - 1];
if (time_major) {
const int input_step = n_batch * n_input;
const int output_step = n_batch * output_batch_leading_dim;
for (int t = 0; t < max_time; t++) {
const int t_rel = t;
int8_t* output_ptr = GetTensorData<int8_t>(output) + t_rel * output_step;
const int8_t* input_ptr = GetTensorData<int8_t>(input) + t_rel * input_step;
const int8_t* input_ptr =
GetTensorData<int8_t>(input) + t_rel * input_step;
LstmStepInteger8x8_16(
input_ptr, GetTensorData<int8_t>(input_to_input_weights),
integer_lstm_param->effective_input_to_input_scale_a,
@ -2268,7 +2272,8 @@ TfLiteStatus EvalInteger8x8_16(
GetTensorData<int32_t>(cell_gate_bias),
GetTensorData<int32_t>(output_gate_bias),
integer_lstm_param->quantized_cell_clip,
integer_lstm_param->quantized_proj_clip, integer_lstm_param->cell_scale,
integer_lstm_param->quantized_proj_clip,
integer_lstm_param->cell_scale,
integer_lstm_param->input_variance_guard,
integer_lstm_param->forget_variance_guard,
integer_lstm_param->cell_variance_guard,
@ -2282,13 +2287,113 @@ TfLiteStatus EvalInteger8x8_16(
integer_lstm_param->input_to_input_effective_bias.get(),
integer_lstm_param->recurrent_to_input_effective_bias.get(),
integer_lstm_param->projection_effective_bias.get(), n_batch, n_cell,
n_input, n_output, GetTensorData<int8_t>(output_state), output_state_zp,
GetTensorData<int16_t>(cell_state), output_ptr,
n_input, n_output, GetTensorData<int8_t>(output_state),
output_state_zp, GetTensorData<int16_t>(cell_state), output_ptr,
GetTensorData<int16_t>(scratch0), GetTensorData<int16_t>(scratch1),
GetTensorData<int16_t>(scratch2), GetTensorData<int16_t>(scratch3),
GetTensorData<int8_t>(scratch4), GetTensorData<int32_t>(scratch5),
context);
}
} else {
for (int b = 0; b < n_batch; b++) {
const int input_step = n_input;
const int output_step = output_batch_leading_dim;
for (int t = 0; t < max_time; t++) {
// If this is the forward_sequence, step forward, otherwise step
// backwards.
const int t_rel = forward_sequence ? t : max_time - t - 1;
const int time_offset = b * max_time + t_rel;
const int8_t* input_ptr =
GetTensorData<int8_t>(input) + time_offset * input_step;
int8_t* output_ptr =
GetTensorData<int8_t>(output) + time_offset * output_step;
// Offset the {output,cell}_state pointers to the right batch.
int8_t* output_state_ptr =
GetTensorData<int8_t>(output_state) + b * output_batch_leading_dim;
int16_t* cell_state_ptr =
GetTensorData<int16_t>(cell_state) + b * n_cell;
LstmStepInteger8x8_16(
input_ptr, GetTensorData<int8_t>(input_to_input_weights),
integer_lstm_param->effective_input_to_input_scale_a,
integer_lstm_param->effective_input_to_input_scale_b,
GetTensorData<int8_t>(input_to_forget_weights),
integer_lstm_param->effective_input_to_forget_scale_a,
integer_lstm_param->effective_input_to_forget_scale_b,
GetTensorData<int8_t>(input_to_cell_weights),
integer_lstm_param->effective_input_to_cell_scale_a,
integer_lstm_param->effective_input_to_cell_scale_b,
GetTensorData<int8_t>(input_to_output_weights),
integer_lstm_param->effective_input_to_output_scale_a,
integer_lstm_param->effective_input_to_output_scale_b,
GetTensorData<int8_t>(recurrent_to_input_weights),
integer_lstm_param->effective_recurrent_to_input_scale_a,
integer_lstm_param->effective_recurrent_to_input_scale_b,
GetTensorData<int8_t>(recurrent_to_forget_weights),
integer_lstm_param->effective_recurrent_to_forget_scale_a,
integer_lstm_param->effective_recurrent_to_forget_scale_b,
GetTensorData<int8_t>(recurrent_to_cell_weights),
integer_lstm_param->effective_recurrent_to_cell_scale_a,
integer_lstm_param->effective_recurrent_to_cell_scale_b,
GetTensorData<int8_t>(recurrent_to_output_weights),
integer_lstm_param->effective_recurrent_to_output_scale_a,
integer_lstm_param->effective_recurrent_to_output_scale_b,
GetTensorData<int16_t>(cell_to_input_weights),
integer_lstm_param->effective_cell_to_input_scale_a,
integer_lstm_param->effective_cell_to_input_scale_b,
GetTensorData<int16_t>(cell_to_forget_weights),
integer_lstm_param->effective_cell_to_forget_scale_a,
integer_lstm_param->effective_cell_to_forget_scale_b,
GetTensorData<int16_t>(cell_to_output_weights),
integer_lstm_param->effective_cell_to_output_scale_a,
integer_lstm_param->effective_cell_to_output_scale_b,
GetTensorData<int8_t>(projection_weights),
integer_lstm_param->effective_proj_scale_a,
integer_lstm_param->effective_proj_scale_b,
integer_lstm_param->hidden_zp,
integer_lstm_param->effective_hidden_scale_a,
integer_lstm_param->effective_hidden_scale_b,
GetTensorData<int16_t>(input_layer_norm_coefficients),
integer_lstm_param->layer_norm_input_scale_a,
integer_lstm_param->layer_norm_input_scale_b,
GetTensorData<int16_t>(forget_layer_norm_coefficients),
integer_lstm_param->layer_norm_forget_scale_a,
integer_lstm_param->layer_norm_forget_scale_b,
GetTensorData<int16_t>(cell_layer_norm_coefficients),
integer_lstm_param->layer_norm_cell_scale_a,
integer_lstm_param->layer_norm_cell_scale_b,
GetTensorData<int16_t>(output_layer_norm_coefficients),
integer_lstm_param->layer_norm_output_scale_a,
integer_lstm_param->layer_norm_output_scale_b,
GetTensorData<int32_t>(input_gate_bias),
GetTensorData<int32_t>(forget_gate_bias),
GetTensorData<int32_t>(cell_gate_bias),
GetTensorData<int32_t>(output_gate_bias),
integer_lstm_param->quantized_cell_clip,
integer_lstm_param->quantized_proj_clip,
integer_lstm_param->cell_scale,
integer_lstm_param->input_variance_guard,
integer_lstm_param->forget_variance_guard,
integer_lstm_param->cell_variance_guard,
integer_lstm_param->output_variance_guard,
integer_lstm_param->input_to_forget_effective_bias.get(),
integer_lstm_param->recurrent_to_forget_effective_bias.get(),
integer_lstm_param->input_to_cell_effective_bias.get(),
integer_lstm_param->recurrent_to_cell_effective_bias.get(),
integer_lstm_param->input_to_output_effective_bias.get(),
integer_lstm_param->recurrent_to_output_effective_bias.get(),
integer_lstm_param->input_to_input_effective_bias.get(),
integer_lstm_param->recurrent_to_input_effective_bias.get(),
integer_lstm_param->projection_effective_bias.get(), /*n_batch=*/1,
n_cell, n_input, n_output, output_state_ptr, output_state_zp,
cell_state_ptr, output_ptr, GetTensorData<int16_t>(scratch0),
GetTensorData<int16_t>(scratch1), GetTensorData<int16_t>(scratch2),
GetTensorData<int16_t>(scratch3), GetTensorData<int8_t>(scratch4),
GetTensorData<int32_t>(scratch5), context);
}
}
}
return kTfLiteOk;
}

View File

@ -188,7 +188,7 @@ TfLiteStatus EvalInteger8x8_16(
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params,
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
const lstm_eval::IntegerLstmParameter* integer_lstm_param,
TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output,
TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,

View File

@ -617,8 +617,9 @@ void TestOneFullyQuantizedLSTM() {
one_parameter.GetOutputLayerNorm(), one_parameter.GetInputBias(),
one_parameter.GetForgetBias(), one_parameter.GetCellBias(),
one_parameter.GetOutputBias(), one_parameter.GetProjection(),
one_parameter.GetProjectionBias(), nullptr, param, activation, cell,
output, one_parameter.GetScratch0(), one_parameter.GetScratch1(),
one_parameter.GetProjectionBias(), nullptr, /*forward_sequence=*/true,
/*time_major=*/true, param, activation, cell, output,
one_parameter.GetScratch0(), one_parameter.GetScratch1(),
one_parameter.GetScratch2(), one_parameter.GetScratch3(),
one_parameter.GetScratch4(), one_parameter.GetScratch5(), &context);

View File

@ -22,6 +22,8 @@ limitations under the License.
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_eval.h"
@ -31,15 +33,350 @@ namespace tflite {
namespace ops {
namespace builtin {
namespace unidirectional_sequence_lstm {
namespace {
struct OpData {
// If the lstm is layer norm.
bool is_layer_norm_lstm;
bool use_layer_norm;
// The scratch tensor index.
int scratch_tensor_index;
bool compute_row_sums = false;
lstm_eval::IntegerLstmParameter integer_lstm_param;
};
TfLiteStatus PopulateQuantizedLstmParams8x8_16(
TfLiteContext* context, TfLiteNode* node,
lstm_eval::IntegerLstmParameter* integer_lstm_param) {
// Calculate quantized clip for projection and cell.
const auto* params =
static_cast<TfLiteUnidirectionalSequenceLSTMParams*>(node->builtin_data);
const float cell_clip = params->cell_clip;
const float proj_clip = params->proj_clip;
const TfLiteTensor* cell_state =
GetVariableInput(context, node, lstm::full::kCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr);
TfLiteTensor* output_tensor;
TF_LITE_ENSURE_OK(
context,
GetOutputSafe(context, node, lstm::full::kOutputTensor, &output_tensor));
auto* cell_state_params =
static_cast<TfLiteAffineQuantization*>(cell_state->quantization.params);
auto* proj_params = static_cast<TfLiteAffineQuantization*>(
output_tensor->quantization.params);
if (cell_clip > 0.0) {
integer_lstm_param->quantized_cell_clip = static_cast<int16_t>(std::min(
std::max(cell_clip / cell_state_params->scale->data[0], -32768.0f),
32767.0f));
} else {
integer_lstm_param->quantized_cell_clip = 0;
}
if (proj_clip > 0.0) {
integer_lstm_param->quantized_proj_clip = static_cast<int8_t>(std::min(
std::max(proj_clip / proj_params->scale->data[0], -128.0f), 127.0f));
} else {
integer_lstm_param->quantized_proj_clip = 0;
}
// Calculate effective scales.
OpData* op_data = static_cast<OpData*>(node->user_data);
const bool use_layer_norm = op_data->use_layer_norm;
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));
const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
context, node, lstm::full::kInputToInputWeightsTensor);
const TfLiteTensor* input_to_forget_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
&input_to_forget_weights));
const TfLiteTensor* input_to_cell_weights;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
lstm::full::kInputToCellWeightsTensor,
&input_to_cell_weights));
const TfLiteTensor* input_to_output_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
&input_to_output_weights));
const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
context, node, lstm::full::kRecurrentToInputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
&recurrent_to_forget_weights));
const TfLiteTensor* recurrent_to_cell_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
&recurrent_to_cell_weights));
const TfLiteTensor* recurrent_to_output_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
&recurrent_to_output_weights));
const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
context, node, lstm::full::kCellToInputWeightsTensor);
const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
context, node, lstm::full::kCellToForgetWeightsTensor);
const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
context, node, lstm::full::kCellToOutputWeightsTensor);
const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
context, node, lstm::full::kInputLayerNormCoefficientsTensor);
const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
context, node, lstm::full::kForgetLayerNormCoefficientsTensor);
const TfLiteTensor* cell_layer_norm_coefficients = GetOptionalInputTensor(
context, node, lstm::full::kCellLayerNormCoefficientsTensor);
const TfLiteTensor* output_layer_norm_coefficients = GetOptionalInputTensor(
context, node, lstm::full::kOutputLayerNormCoefficientsTensor);
const TfLiteTensor* projection_weights = GetOptionalInputTensor(
context, node, lstm::full::kProjectionWeightsTensor);
TfLiteTensor* output_state =
GetVariableInput(context, node, lstm::full::kOutputStateTensor);
TF_LITE_ENSURE(context, output_state != nullptr);
// Since we have already checked that weights are all there or none, we can
// check the existence of only one to get the condition.
const bool use_cifg = (input_to_input_weights == nullptr);
const bool use_peephole = (cell_to_output_weights != nullptr);
const bool use_projection = (projection_weights != nullptr);
// Get intermediate scales and zero points.
std::vector<float> intermediate_scale;
std::vector<int32> intermediate_zp;
for (int i = 0; i < 4; ++i) {
if (use_layer_norm) {
TfLiteTensor* intermediate;
TF_LITE_ENSURE_OK(context,
GetIntermediatesSafe(context, node, i, &intermediate));
auto* params = static_cast<TfLiteAffineQuantization*>(
intermediate->quantization.params);
intermediate_scale.push_back(params->scale->data[0]);
intermediate_zp.push_back(params->zero_point->data[0]);
} else {
// Q3.12 for activation functions.
intermediate_scale.push_back(std::pow(2, -12));
intermediate_zp.push_back(0);
}
}
// In the absense of projection, hidden becomes otuput and this intermediate
// is ignored.
TfLiteTensor* hidden;
TF_LITE_ENSURE_OK(context, GetIntermediatesSafe(context, node, 4, &hidden));
auto* hidden_params =
static_cast<TfLiteAffineQuantization*>(hidden->quantization.params);
intermediate_scale.push_back(hidden_params->scale->data[0]);
intermediate_zp.push_back(hidden_params->zero_point->data[0]);
// Scales.
const float default_scale = 1.0;
float input_scale = default_scale;
float input_to_input_weight_scale = default_scale;
float recurrent_to_input_weight_scale = default_scale;
float cell_to_input_weight_scale = default_scale;
float input_to_forget_weight_scale = default_scale;
float recurrent_to_forget_weight_scale = default_scale;
float cell_to_forget_weight_scale = default_scale;
float input_to_cell_weight_scale = default_scale;
float recurrent_to_cell_weight_scale = default_scale;
float input_to_output_weight_scale = default_scale;
float recurrent_to_output_weight_scale = default_scale;
float cell_to_output_weight_scale = default_scale;
float projection_weight_scale = default_scale;
float layer_norm_input_scale = default_scale;
float layer_norm_forget_scale = default_scale;
float layer_norm_cell_scale = default_scale;
float layer_norm_output_scale = default_scale;
float output_state_scale = default_scale;
int cell_scale = 1;
// Effective scales.
float effective_input_to_input_scale = default_scale;
float effective_recurrent_to_input_scale = default_scale;
float effective_cell_to_input_scale = default_scale;
float effective_input_to_forget_scale = default_scale;
float effective_recurrent_to_forget_scale = default_scale;
float effective_cell_to_forget_scale = default_scale;
float effective_input_to_cell_scale = default_scale;
float effective_recurrent_to_cell_scale = default_scale;
float effective_input_to_output_scale = default_scale;
float effective_recurrent_to_output_scale = default_scale;
float effective_cell_to_output_scale = default_scale;
float effective_proj_scale = default_scale;
float effective_hidden_scale = default_scale;
// Populate scales.
if (!use_cifg) {
input_to_input_weight_scale = input_to_input_weights->params.scale;
recurrent_to_input_weight_scale = recurrent_to_input_weights->params.scale;
}
if (use_peephole) {
if (!use_cifg) {
cell_to_input_weight_scale = cell_to_input_weights->params.scale;
}
cell_to_forget_weight_scale = cell_to_forget_weights->params.scale;
cell_to_output_weight_scale = cell_to_output_weights->params.scale;
}
if (use_layer_norm) {
if (!use_cifg) {
layer_norm_input_scale = input_layer_norm_coefficients->params.scale;
}
layer_norm_forget_scale = forget_layer_norm_coefficients->params.scale;
layer_norm_cell_scale = cell_layer_norm_coefficients->params.scale;
layer_norm_output_scale = output_layer_norm_coefficients->params.scale;
}
if (use_projection) {
projection_weight_scale = projection_weights->params.scale;
}
output_state_scale = output_state->params.scale;
input_to_forget_weight_scale = input_to_forget_weights->params.scale;
input_to_cell_weight_scale = input_to_cell_weights->params.scale;
input_to_output_weight_scale = input_to_output_weights->params.scale;
recurrent_to_forget_weight_scale = recurrent_to_forget_weights->params.scale;
recurrent_to_cell_weight_scale = recurrent_to_cell_weights->params.scale;
recurrent_to_output_weight_scale = recurrent_to_output_weights->params.scale;
// Check cell state (already used above)
TF_LITE_ENSURE(context, CheckedLog2(cell_state->params.scale, &cell_scale));
// TF_LITE_ENSURE(context, cell_scale <= -9);
integer_lstm_param->cell_scale = cell_scale;
input_scale = input->params.scale;
// Calculate effective scales.
if (!use_cifg) {
effective_input_to_input_scale =
input_to_input_weight_scale * input_scale / intermediate_scale[0];
effective_recurrent_to_input_scale = recurrent_to_input_weight_scale *
output_state_scale /
intermediate_scale[0];
}
effective_input_to_forget_scale =
input_to_forget_weight_scale * input_scale / intermediate_scale[1];
effective_recurrent_to_forget_scale = recurrent_to_forget_weight_scale *
output_state_scale /
intermediate_scale[1];
effective_input_to_cell_scale =
input_to_cell_weight_scale * input_scale / intermediate_scale[2];
effective_recurrent_to_cell_scale = recurrent_to_cell_weight_scale *
output_state_scale /
intermediate_scale[2];
effective_input_to_output_scale =
input_to_output_weight_scale * input_scale / intermediate_scale[3];
effective_recurrent_to_output_scale = recurrent_to_output_weight_scale *
output_state_scale /
intermediate_scale[3];
effective_hidden_scale =
std::pow(2, -15) / intermediate_scale[4] * std::pow(2, -15);
effective_proj_scale =
projection_weight_scale * intermediate_scale[4] / output_state_scale;
if (use_peephole) {
if (!use_cifg) {
effective_cell_to_input_scale = std::pow(2, cell_scale) * // NOLINT
cell_to_input_weight_scale /
intermediate_scale[0];
}
effective_cell_to_forget_scale = std::pow(2, cell_scale) * // NOLINT
cell_to_forget_weight_scale /
intermediate_scale[1];
effective_cell_to_output_scale = std::pow(2, cell_scale) * // NOLINT
cell_to_output_weight_scale /
intermediate_scale[3];
}
// Decompose scales.
QuantizeMultiplier(effective_input_to_input_scale,
&integer_lstm_param->effective_input_to_input_scale_a,
&integer_lstm_param->effective_input_to_input_scale_b);
QuantizeMultiplier(effective_recurrent_to_input_scale,
&integer_lstm_param->effective_recurrent_to_input_scale_a,
&integer_lstm_param->effective_recurrent_to_input_scale_b);
QuantizeMultiplier(effective_cell_to_input_scale,
&integer_lstm_param->effective_cell_to_input_scale_a,
&integer_lstm_param->effective_cell_to_input_scale_b);
QuantizeMultiplier(effective_input_to_forget_scale,
&integer_lstm_param->effective_input_to_forget_scale_a,
&integer_lstm_param->effective_input_to_forget_scale_b);
QuantizeMultiplier(
effective_recurrent_to_forget_scale,
&integer_lstm_param->effective_recurrent_to_forget_scale_a,
&integer_lstm_param->effective_recurrent_to_forget_scale_b);
QuantizeMultiplier(effective_cell_to_forget_scale,
&integer_lstm_param->effective_cell_to_forget_scale_a,
&integer_lstm_param->effective_cell_to_forget_scale_b);
QuantizeMultiplier(effective_input_to_cell_scale,
&integer_lstm_param->effective_input_to_cell_scale_a,
&integer_lstm_param->effective_input_to_cell_scale_b);
QuantizeMultiplier(effective_recurrent_to_cell_scale,
&integer_lstm_param->effective_recurrent_to_cell_scale_a,
&integer_lstm_param->effective_recurrent_to_cell_scale_b);
QuantizeMultiplier(effective_input_to_output_scale,
&integer_lstm_param->effective_input_to_output_scale_a,
&integer_lstm_param->effective_input_to_output_scale_b);
QuantizeMultiplier(
effective_recurrent_to_output_scale,
&integer_lstm_param->effective_recurrent_to_output_scale_a,
&integer_lstm_param->effective_recurrent_to_output_scale_b);
QuantizeMultiplier(effective_cell_to_output_scale,
&integer_lstm_param->effective_cell_to_output_scale_a,
&integer_lstm_param->effective_cell_to_output_scale_b);
QuantizeMultiplier(effective_proj_scale,
&integer_lstm_param->effective_proj_scale_a,
&integer_lstm_param->effective_proj_scale_b);
QuantizeMultiplier(effective_hidden_scale,
&integer_lstm_param->effective_hidden_scale_a,
&integer_lstm_param->effective_hidden_scale_b);
QuantizeMultiplier(layer_norm_input_scale,
&integer_lstm_param->layer_norm_input_scale_a,
&integer_lstm_param->layer_norm_input_scale_b);
QuantizeMultiplier(layer_norm_forget_scale,
&integer_lstm_param->layer_norm_forget_scale_a,
&integer_lstm_param->layer_norm_forget_scale_b);
QuantizeMultiplier(layer_norm_cell_scale,
&integer_lstm_param->layer_norm_cell_scale_a,
&integer_lstm_param->layer_norm_cell_scale_b);
QuantizeMultiplier(layer_norm_output_scale,
&integer_lstm_param->layer_norm_output_scale_a,
&integer_lstm_param->layer_norm_output_scale_b);
integer_lstm_param->hidden_zp = intermediate_zp[4];
// 10000 is used to make sure the kernel logic does not overflow.
if (!use_cifg) {
integer_lstm_param->input_variance_guard =
std::max(1, static_cast<int32_t>(10000 * layer_norm_input_scale));
}
integer_lstm_param->forget_variance_guard =
std::max(1, static_cast<int32_t>(10000 * layer_norm_forget_scale));
integer_lstm_param->cell_variance_guard =
std::max(1, static_cast<int32_t>(10000 * layer_norm_cell_scale));
integer_lstm_param->output_variance_guard =
std::max(1, static_cast<int32_t>(10000 * layer_norm_output_scale));
return kTfLiteOk;
}
} // namespace
// Temporary tensors
enum TemporaryTensor {
kScratchBuffer = 0,
@ -72,7 +409,7 @@ void Free(TfLiteContext* context, void* buffer) {
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TfLiteNode* node, int n_input,
int n_output, int n_cell,
bool is_layer_norm_lstm) {
bool use_layer_norm, bool is_integer) {
const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
// Making sure clipping parameters have valid values.
@ -151,6 +488,9 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
if (cell_to_input_weights != nullptr) {
TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_TYPES_EQ(
context, cell_to_input_weights->type,
is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
}
const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
@ -158,6 +498,9 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
if (cell_to_forget_weights != nullptr) {
TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_TYPES_EQ(
context, cell_to_forget_weights->type,
is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
}
const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
@ -165,6 +508,9 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
if (cell_to_output_weights != nullptr) {
TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_TYPES_EQ(
context, cell_to_output_weights->type,
is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
}
// Making sure the peephole weights are there all or none.
@ -186,6 +532,11 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
} else {
TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
if (is_integer) {
TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteInt32);
} else {
TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32);
}
}
const TfLiteTensor* forget_gate_bias;
@ -194,6 +545,11 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
&forget_gate_bias));
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
if (is_integer) {
TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteInt32);
} else {
TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);
}
const TfLiteTensor* cell_gate_bias;
TF_LITE_ENSURE_OK(context,
@ -201,6 +557,11 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
&cell_gate_bias));
TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
if (is_integer) {
TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteInt32);
} else {
TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteFloat32);
}
const TfLiteTensor* output_gate_bias;
TF_LITE_ENSURE_OK(
@ -208,6 +569,11 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
&output_gate_bias));
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
if (is_integer) {
TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteInt32);
} else {
TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32);
}
const TfLiteTensor* projection_weights = GetOptionalInputTensor(
context, node, lstm::full::kProjectionWeightsTensor);
@ -222,6 +588,11 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
if (projection_bias != nullptr) {
TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
if (is_integer) {
TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteInt32);
} else {
TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteFloat32);
}
}
// Making sure the projection tensors are consistent:
@ -233,7 +604,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
((projection_weights != nullptr) || (projection_bias == nullptr));
TF_LITE_ENSURE(context, projecton_tensors_consistent == true);
if (is_layer_norm_lstm) {
if (use_layer_norm) {
const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
context, node, lstm::full::kInputLayerNormCoefficientsTensor);
if (use_cifg) {
@ -243,9 +614,14 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0],
n_cell);
if (is_integer) {
TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type,
kTfLiteInt16);
} else {
TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type,
kTfLiteFloat32);
}
}
const TfLiteTensor* forget_layer_norm_coefficients;
TF_LITE_ENSURE_OK(
@ -255,8 +631,13 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
n_cell);
if (is_integer) {
TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
kTfLiteInt16);
} else {
TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
kTfLiteFloat32);
}
const TfLiteTensor* cell_layer_norm_coefficients;
TF_LITE_ENSURE_OK(context,
@ -266,8 +647,13 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
n_cell);
if (is_integer) {
TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
kTfLiteInt16);
} else {
TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
kTfLiteFloat32);
}
const TfLiteTensor* output_layer_norm_coefficients;
TF_LITE_ENSURE_OK(
@ -277,13 +663,185 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
n_cell);
if (is_integer) {
TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type,
kTfLiteInt16);
} else {
TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type,
kTfLiteFloat32);
}
}
return kTfLiteOk;
}
TfLiteStatus PrecomputeZeroPointTimesWeightWithBias(
TfLiteContext* context, int32_t zero_point,
const TfLiteTensor* weight_tensor, const TfLiteTensor* bias_tensor,
std::unique_ptr<int32_t[]>* output) {
if (weight_tensor == nullptr) {
return kTfLiteOk;
}
const RuntimeShape& weight_shape = GetTensorShape(weight_tensor);
TF_LITE_ENSURE_EQ(context, weight_shape.DimensionsCount(), 2);
const int row = weight_shape.Dims(0);
const int col = weight_shape.Dims(1);
output->reset(new int32_t[row]);
if (bias_tensor == nullptr) {
memset(output->get(), 0, row * sizeof(int32_t));
} else {
const int32_t* bias = GetTensorData<int32_t>(bias_tensor);
memcpy(output->get(), bias, row * sizeof(int32_t));
}
if (zero_point != 0) {
const int8_t* weight = GetTensorData<int8_t>(weight_tensor);
tensor_utils::MatrixScalarMultiplyAccumulate(weight, zero_point, row, col,
output->get());
}
return kTfLiteOk;
}
TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
OpData* op_data,
TfLiteNode* node) {
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));
const TfLiteTensor* output_state =
GetVariableInput(context, node, lstm::full::kOutputStateTensor);
TF_LITE_ENSURE(context, output_state != nullptr);
const int32_t input_zero_point = -input->params.zero_point;
const int32_t output_state_zero_point = -output_state->params.zero_point;
const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
context, node, lstm::full::kInputToInputWeightsTensor);
const TfLiteTensor* input_to_forget_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
&input_to_forget_weights));
const TfLiteTensor* input_to_cell_weights;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
lstm::full::kInputToCellWeightsTensor,
&input_to_cell_weights));
const TfLiteTensor* input_to_output_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
&input_to_output_weights));
const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
context, node, lstm::full::kRecurrentToInputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
&recurrent_to_forget_weights));
const TfLiteTensor* recurrent_to_cell_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
&recurrent_to_cell_weights));
const TfLiteTensor* recurrent_to_output_weights;
TF_LITE_ENSURE_OK(
context,
GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
&recurrent_to_output_weights));
const TfLiteTensor* projection_weights = GetOptionalInputTensor(
context, node, lstm::full::kProjectionWeightsTensor);
const TfLiteTensor* projection_bias =
GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor);
lstm_eval::IntegerLstmParameter* integer_lstm_params =
&op_data->integer_lstm_param;
const TfLiteTensor* intermediate =
&context->tensors[node->intermediates->data[4]];
const auto* params =
static_cast<TfLiteAffineQuantization*>(intermediate->quantization.params);
const int32_t hidden_zp = params->zero_point->data[0];
// Get bias and perform zero point calculation.
// When there is layer normalization, the gate bias does not apply to matmul
// directly:
// y = ln(w * x + w * r + w * c) + b.
const bool is_layer_norm = op_data->use_layer_norm;
// Forget gate.
const TfLiteTensor* forget_gate_bias =
is_layer_norm
? nullptr
: GetInput(context, node, lstm::full::kForgetGateBiasTensor);
TF_LITE_ENSURE_OK(
context,
PrecomputeZeroPointTimesWeightWithBias(
context, input_zero_point, input_to_forget_weights, forget_gate_bias,
&(integer_lstm_params->input_to_forget_effective_bias)));
TF_LITE_ENSURE_OK(
context,
PrecomputeZeroPointTimesWeightWithBias(
context, output_state_zero_point, recurrent_to_forget_weights,
nullptr, &(integer_lstm_params->recurrent_to_forget_effective_bias)));
// Modulation gate.
const TfLiteTensor* cell_gate_bias =
is_layer_norm ? nullptr
: GetInput(context, node, lstm::full::kCellGateBiasTensor);
TF_LITE_ENSURE_OK(
context,
PrecomputeZeroPointTimesWeightWithBias(
context, input_zero_point, input_to_cell_weights, cell_gate_bias,
&(integer_lstm_params->input_to_cell_effective_bias)));
TF_LITE_ENSURE_OK(
context,
PrecomputeZeroPointTimesWeightWithBias(
context, output_state_zero_point, recurrent_to_cell_weights, nullptr,
&(integer_lstm_params->recurrent_to_cell_effective_bias)));
// Output gate.
const TfLiteTensor* output_gate_bias =
is_layer_norm
? nullptr
: GetInput(context, node, lstm::full::kOutputGateBiasTensor);
TF_LITE_ENSURE_OK(
context,
PrecomputeZeroPointTimesWeightWithBias(
context, input_zero_point, input_to_output_weights, output_gate_bias,
&(integer_lstm_params->input_to_output_effective_bias)));
TF_LITE_ENSURE_OK(
context,
PrecomputeZeroPointTimesWeightWithBias(
context, output_state_zero_point, recurrent_to_output_weights,
nullptr, &(integer_lstm_params->recurrent_to_output_effective_bias)));
// Input gate. The calculation is only meaningful for non-cifg case.
const TfLiteTensor* input_gate_bias =
is_layer_norm ? nullptr
: GetInput(context, node, lstm::full::kInputGateBiasTensor);
TF_LITE_ENSURE_OK(
context,
PrecomputeZeroPointTimesWeightWithBias(
context, input_zero_point, input_to_input_weights, input_gate_bias,
&(integer_lstm_params->input_to_input_effective_bias)));
TF_LITE_ENSURE_OK(
context,
PrecomputeZeroPointTimesWeightWithBias(
context, output_state_zero_point, recurrent_to_input_weights, nullptr,
&(integer_lstm_params->recurrent_to_input_effective_bias)));
// Projection bias. The calculation is only meaningful for with projection.
TF_LITE_ENSURE_OK(context,
PrecomputeZeroPointTimesWeightWithBias(
context, hidden_zp, projection_weights, projection_bias,
&(integer_lstm_params->projection_effective_bias)));
return kTfLiteOk;
}
// Resize the output and state tensors based on the sizes of the input tensors.
// Allocate a temporary scratch tensor. Also check that the sizes of the input
// tensors match each other.
@ -292,18 +850,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int scratch_tensor_index = op_data->scratch_tensor_index;
// Check we have all the inputs and outputs we need.
bool is_layer_norm_lstm = false;
bool use_layer_norm = false;
if (node->inputs->size == 24) {
const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
context, node, lstm::full::kForgetLayerNormCoefficientsTensor);
if (forget_layer_norm_coefficients == nullptr) {
is_layer_norm_lstm = false;
use_layer_norm = false;
} else {
is_layer_norm_lstm = true;
use_layer_norm = true;
}
} else if (node->inputs->size == 20) {
// This is deprecated and is only kept here for backward compatibility.
is_layer_norm_lstm = false;
use_layer_norm = false;
} else {
context->ReportError(
context, "The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs",
@ -311,14 +869,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteError;
}
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
op_data->is_layer_norm_lstm = is_layer_norm_lstm;
op_data->use_layer_norm = use_layer_norm;
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
const bool is_integer = input->type == kTfLiteInt8;
TF_LITE_ENSURE(context, input->dims->size > 1);
const auto* params =
reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
@ -347,9 +905,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int n_output = recurrent_to_output_weights->dims->data[1];
// Check that input tensor dimensions matches with each other.
TF_LITE_ENSURE_OK(context,
CheckInputTensorDimensions(context, node, n_input, n_output,
n_cell, is_layer_norm_lstm));
TF_LITE_ENSURE_OK(
context, CheckInputTensorDimensions(context, node, n_input, n_output,
n_cell, use_layer_norm, is_integer));
// Get the pointer to output, output_state and cell_state buffer tensors.
TfLiteTensor* output;
@ -375,9 +933,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size));
if (is_integer) {
const int num_intermediate_tensors = node->intermediates->size;
TF_LITE_ENSURE(context, num_intermediate_tensors == 5);
}
TfLiteIntArrayFree(node->temporaries);
if (IsHybridOp(input, input_to_output_weights)) {
node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
} else if (is_integer) {
node->temporaries = TfLiteIntArrayCreate(6);
} else {
node->temporaries = TfLiteIntArrayCreate(1);
}
@ -590,6 +1155,50 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context, context->ResizeTensor(context, row_sums, row_sums_size));
}
}
if (is_integer) {
// Integer UnidirectionalSequenceLSTM prepare function for 8x8->16.
// This code path needs 5 intermediate tensors per Op.
// Populate quantization parameters.
PopulateQuantizedLstmParams8x8_16(context, node,
&op_data->integer_lstm_param);
// Allocate scratch buffer. Need 6 16bit buffer with size n_batch * n_cell
// and 1 8bit buffer with size n_batch * n_cell. We also need 1 32 bit
// buffer with size n_batch * n_cell.
//
// Handle cifg case as well, which might save one buffer.
for (int scratch_index = 0; scratch_index < 6; ++scratch_index) {
node->temporaries->data[scratch_index] =
op_data->scratch_tensor_index + scratch_index;
TfLiteTensor* scratch_tensor;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, scratch_index,
&scratch_tensor));
scratch_tensor->type = kTfLiteInt16;
if (scratch_index == 4) {
scratch_tensor->type = kTfLiteInt8;
} else if (scratch_index == 5) {
scratch_tensor->type = kTfLiteInt32;
}
scratch_tensor->allocation_type = kTfLiteArenaRw;
const int scratch_dimension[2] = {n_batch, n_cell};
if (!TfLiteIntArrayEqualsArray(scratch_tensor->dims, 2,
scratch_dimension)) {
TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
scratch_buffer_size->data[0] = n_batch;
scratch_buffer_size->data[1] = n_cell;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, scratch_tensor,
scratch_buffer_size));
}
}
// Populate precomputed zp * weight.
TF_LITE_ENSURE_OK(context, PopulatePrecomputedZPTimesWeightsWithBias(
context, op_data, node));
}
return kTfLiteOk;
}
@ -598,7 +1207,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
node->builtin_data);
const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
const bool is_layer_norm_lstm = op_data->is_layer_norm_lstm;
const bool use_layer_norm = op_data->use_layer_norm;
const bool time_major = params->time_major;
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(
@ -666,11 +1275,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* projection_bias =
GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor);
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer,
&scratch_buffer));
TfLiteTensor* output_state =
GetVariableInput(context, node, lstm::full::kOutputStateTensor);
TFLITE_DCHECK(output_state != nullptr);
@ -679,23 +1283,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(cell_state != nullptr);
const TfLiteTensor* input_layer_norm_coefficients =
is_layer_norm_lstm
use_layer_norm
? GetOptionalInputTensor(
context, node, lstm::full::kInputLayerNormCoefficientsTensor)
: nullptr;
const TfLiteTensor* forget_layer_norm_coefficients =
is_layer_norm_lstm
? GetInput(context, node,
use_layer_norm ? GetInput(context, node,
lstm::full::kForgetLayerNormCoefficientsTensor)
: nullptr;
const TfLiteTensor* cell_layer_norm_coefficients =
is_layer_norm_lstm
? GetInput(context, node,
use_layer_norm ? GetInput(context, node,
lstm::full::kCellLayerNormCoefficientsTensor)
: nullptr;
const TfLiteTensor* output_layer_norm_coefficients =
is_layer_norm_lstm
? GetInput(context, node,
use_layer_norm ? GetInput(context, node,
lstm::full::kOutputLayerNormCoefficientsTensor)
: nullptr;
@ -712,6 +1313,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
switch (input_to_output_weights->type) {
case kTfLiteFloat32: {
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer,
&scratch_buffer));
return lstm_eval::EvalFloat(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
@ -733,6 +1338,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
case kTfLiteUInt8:
case kTfLiteInt8: {
const bool is_hybrid = input->type == kTfLiteFloat32;
if (is_hybrid) {
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* scratch_buffer;
TF_LITE_ENSURE_OK(
context,
GetTemporarySafe(context, node, kScratchBuffer, &scratch_buffer));
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* row_sums;
TF_LITE_ENSURE_OK(context,
@ -751,10 +1364,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
recurrent_to_cell_weights,
/*recurrent_to_cell_weights_ledger*/ nullptr,
recurrent_to_output_weights,
/*recurrent_to_output_weights_ledger*/ nullptr, cell_to_input_weights,
cell_to_forget_weights, cell_to_output_weights,
input_layer_norm_coefficients, forget_layer_norm_coefficients,
cell_layer_norm_coefficients, output_layer_norm_coefficients,
/*recurrent_to_output_weights_ledger*/ nullptr,
cell_to_input_weights, cell_to_forget_weights,
cell_to_output_weights, input_layer_norm_coefficients,
forget_layer_norm_coefficients, cell_layer_norm_coefficients,
output_layer_norm_coefficients,
/*aux_input=*/nullptr,
/*aux_input_to_input_weights=*/nullptr,
/*aux_input_to_forget_weights=*/nullptr,
@ -780,6 +1394,40 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTemporary(context, node, kOutputStateZeroPoints), row_sums,
row_sums_size, &op_data->compute_row_sums,
CpuBackendContext::GetFromContext(context));
} else {
TfLiteTensor* scratch0;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 0, &scratch0));
TfLiteTensor* scratch1;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 1, &scratch1));
TfLiteTensor* scratch2;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 2, &scratch2));
TfLiteTensor* scratch3;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 3, &scratch3));
TfLiteTensor* scratch4;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 4, &scratch4));
TfLiteTensor* scratch5;
TF_LITE_ENSURE_OK(context,
GetTemporarySafe(context, node, 5, &scratch5));
return lstm_eval::EvalInteger8x8_16(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
recurrent_to_input_weights, recurrent_to_forget_weights,
recurrent_to_cell_weights, recurrent_to_output_weights,
cell_to_input_weights, cell_to_forget_weights,
cell_to_output_weights, input_layer_norm_coefficients,
forget_layer_norm_coefficients, cell_layer_norm_coefficients,
output_layer_norm_coefficients, input_gate_bias, forget_gate_bias,
cell_gate_bias, output_gate_bias, projection_weights,
projection_bias, &lstm_params, /*forward_sequence=*/true,
time_major, &op_data->integer_lstm_param, output_state, cell_state,
output, scratch0, scratch1, scratch2, scratch3, scratch4, scratch5,
CpuBackendContext::GetFromContext(context));
}
}
default:
TF_LITE_KERNEL_LOG(context, "Type %s is not currently supported.",

View File

@ -2739,6 +2739,611 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
class UnidirectionalSequenceLSTMIntegerOpModel : public SingleOpModel {
public:
UnidirectionalSequenceLSTMIntegerOpModel(
int n_batch, int n_input, int n_cell, int n_output, int sequence_length,
bool time_major, bool use_cifg, bool use_peephole,
bool use_projection_weights, bool use_projection_bias,
bool use_layer_norm, bool use_8x8_8_implementation,
const std::vector<std::pair<float, float>>& ranges,
const std::vector<std::pair<float, int>>& intermediates,
bool asymmetric_quantize_inputs = false)
: n_input_(n_input), n_output_(n_output) {
input_ = AddInput({TensorType_INT8,
{sequence_length, n_batch, n_input},
ranges[0].first,
ranges[0].second});
if (use_cifg) {
input_to_input_weights_ = AddNullInput();
} else {
input_to_input_weights_ = AddInput({TensorType_INT8,
{n_cell, n_input},
ranges[1].first,
ranges[1].second});
}
input_to_forget_weights_ = AddInput({TensorType_INT8,
{n_cell, n_input},
ranges[2].first,
ranges[2].second});
input_to_cell_weights_ = AddInput({TensorType_INT8,
{n_cell, n_input},
ranges[3].first,
ranges[3].second});
input_to_output_weights_ = AddInput({TensorType_INT8,
{n_cell, n_input},
ranges[4].first,
ranges[4].second});
if (use_cifg) {
recurrent_to_input_weights_ = AddNullInput();
} else {
recurrent_to_input_weights_ = AddInput({TensorType_INT8,
{n_cell, n_output},
ranges[5].first,
ranges[5].second});
}
recurrent_to_forget_weights_ = AddInput({TensorType_INT8,
{n_cell, n_output},
ranges[6].first,
ranges[6].second});
recurrent_to_cell_weights_ = AddInput({TensorType_INT8,
{n_cell, n_output},
ranges[7].first,
ranges[7].second});
recurrent_to_output_weights_ = AddInput({TensorType_INT8,
{n_cell, n_output},
ranges[8].first,
ranges[8].second});
if (use_peephole) {
if (use_cifg) {
cell_to_input_weights_ = AddNullInput();
} else {
cell_to_input_weights_ = AddInput(
{TensorType_INT16, {n_cell}, ranges[9].first, ranges[9].second});
}
cell_to_forget_weights_ = AddInput(
{TensorType_INT16, {n_cell}, ranges[10].first, ranges[10].second});
cell_to_output_weights_ = AddInput(
{TensorType_INT16, {n_cell}, ranges[11].first, ranges[11].second});
} else {
cell_to_input_weights_ = AddNullInput();
cell_to_forget_weights_ = AddNullInput();
cell_to_output_weights_ = AddNullInput();
}
if (use_cifg) {
input_gate_bias_ = AddNullInput();
} else {
input_gate_bias_ = AddInput(
{TensorType_INT32, {n_cell}, ranges[12].first, ranges[12].second});
}
forget_gate_bias_ = AddInput(
{TensorType_INT32, {n_cell}, ranges[13].first, ranges[13].second});
cell_gate_bias_ = AddInput(
{TensorType_INT32, {n_cell}, ranges[14].first, ranges[14].second});
output_gate_bias_ = AddInput(
{TensorType_INT32, {n_cell}, ranges[15].first, ranges[15].second});
if (use_projection_weights) {
projection_weights_ = AddInput({TensorType_INT8,
{n_output, n_cell},
ranges[16].first,
ranges[16].second});
} else {
projection_weights_ = AddNullInput();
}
if (use_projection_bias) {
CHECK(use_projection_weights);
projection_bias_ = AddInput(
{TensorType_INT32, {n_output}, ranges[17].first, ranges[17].second});
} else {
projection_bias_ = AddNullInput();
}
// Adding the 2 state tensors.
AddVariableInput({TensorType_INT16,
{n_batch, n_output},
ranges[18].first,
ranges[18].second});
AddVariableInput({TensorType_INT16,
{n_batch, n_cell},
ranges[19].first,
ranges[19].second});
// Layer norm weights.
if (use_layer_norm) {
if (use_cifg) {
input_layer_norm_coefficients_ = AddNullInput();
} else {
input_layer_norm_coefficients_ = AddInput(
{TensorType_INT16, {n_cell}, ranges[20].first, ranges[20].second});
}
forget_layer_norm_coefficients_ = AddInput(
{TensorType_INT16, {n_cell}, ranges[21].first, ranges[21].second});
cell_layer_norm_coefficients_ = AddInput(
{TensorType_INT16, {n_cell}, ranges[22].first, ranges[22].second});
output_layer_norm_coefficients_ = AddInput(
{TensorType_INT16, {n_cell}, ranges[23].first, ranges[23].second});
}
// use_8x8_8_implementation is not supported yet.
CHECK(!use_8x8_8_implementation);
EXPECT_EQ(intermediates.size(), 5);
for (int i = 0; i < intermediates.size(); ++i) {
AddIntermediate(TensorType_INT16, {intermediates[i].first},
{intermediates[i].second});
}
output_ = AddOutput({TensorType_INT8,
{n_batch, n_output},
ranges[24].first,
ranges[24].second});
// TODO(b/161825581): Add tests where cell_clip and/or proj_clip is not the
// default 0.
SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
BuiltinOptions_UnidirectionalSequenceLSTMOptions,
CreateUnidirectionalSequenceLSTMOptions(
builder_, ActivationFunctionType_TANH, /*cell_clip=*/0.0f,
/*proj_clip=*/0.0f, time_major, asymmetric_quantize_inputs)
.Union());
BuildInterpreter(/*input_shapes=*/{}, /*num_threads=*/-1,
/*allow_fp32_relax_to_fp16=*/false,
/*apply_delegate=*/true, /*allocate_and_delegate=*/false);
}
void PerformAllocateAndDelegate() { AllocateAndDelegate(true); }
void SetInputToInputWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int8_t>(input_to_input_weights_, f);
}
void SetInputToForgetWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int8_t>(input_to_forget_weights_, f);
}
void SetInputToCellWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int8_t>(input_to_cell_weights_, f);
}
void SetInputToOutputWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int8_t>(input_to_output_weights_, f);
}
void SetRecurrentToInputWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int8_t>(recurrent_to_input_weights_, f);
}
void SetRecurrentToForgetWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int8_t>(recurrent_to_forget_weights_, f);
}
void SetRecurrentToCellWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int8_t>(recurrent_to_cell_weights_, f);
}
void SetRecurrentToOutputWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int8_t>(recurrent_to_output_weights_, f);
}
void SetCellToInputWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int16_t>(cell_to_input_weights_, f);
}
void SetCellToForgetWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int16_t>(cell_to_forget_weights_, f);
}
void SetCellToOutputWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int16_t>(cell_to_output_weights_, f);
}
void SetInputLayerNormCoefficients(const std::vector<float>& f) {
QuantizeAndPopulate<int16_t>(input_layer_norm_coefficients_, f);
}
void SetForgetLayerNormCoefficients(const std::vector<float>& f) {
QuantizeAndPopulate<int16_t>(forget_layer_norm_coefficients_, f);
}
void SetCellLayerNormCoefficients(const std::vector<float>& f) {
QuantizeAndPopulate<int16_t>(cell_layer_norm_coefficients_, f);
}
void SetOutputLayerNormCoefficients(const std::vector<float>& f) {
QuantizeAndPopulate<int16_t>(output_layer_norm_coefficients_, f);
}
void SetInputGateBias(const std::vector<float>& f) {
QuantizeAndPopulate<int32_t>(input_gate_bias_, f);
}
void SetForgetGateBias(const std::vector<float>& f) {
QuantizeAndPopulate<int32_t>(forget_gate_bias_, f);
}
void SetCellBias(const std::vector<float>& f) {
QuantizeAndPopulate<int32_t>(cell_gate_bias_, f);
}
void SetOutputGateBias(const std::vector<float>& f) {
QuantizeAndPopulate<int32_t>(output_gate_bias_, f);
}
void SetProjectionWeights(const std::vector<float>& f) {
QuantizeAndPopulate<int8_t>(projection_weights_, f);
}
void SetProjectionBias(const std::vector<float>& f) {
QuantizeAndPopulate<int32_t>(projection_bias_, f);
}
void SetInput(const std::vector<float>& f) {
QuantizeAndPopulate<int8_t>(input_, f);
}
std::vector<int8_t> GetOutput() { return ExtractVector<int8_t>(output_); }
int num_inputs() { return n_input_; }
int num_outputs() { return n_output_; }
protected:
int input_;
int input_to_input_weights_;
int input_to_forget_weights_;
int input_to_cell_weights_;
int input_to_output_weights_;
int recurrent_to_input_weights_;
int recurrent_to_forget_weights_;
int recurrent_to_cell_weights_;
int recurrent_to_output_weights_;
int cell_to_input_weights_;
int cell_to_forget_weights_;
int cell_to_output_weights_;
int input_layer_norm_coefficients_;
int forget_layer_norm_coefficients_;
int cell_layer_norm_coefficients_;
int output_layer_norm_coefficients_;
int input_gate_bias_;
int forget_gate_bias_;
int cell_gate_bias_;
int output_gate_bias_;
int projection_weights_;
int projection_bias_;
int output_;
int n_input_;
int n_output_;
};
TEST(IntegerUnidirectionalSequenceLstmOpTest,
NoCifg_NoPeephole_Projection_LayerNorm) {
// Hyper parameters.
const int n_batch = 2;
const int n_input = 5;
const int n_cell = 4;
const int n_output = 3;
const int sequence_length = 3;
// Model related weights.
const std::vector<float> input_to_input_weights = {
0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, 0.3, -0.4, 0.5,
-0.8, 0.7, -0.6, 0.5, -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};
const std::vector<float> input_to_forget_weights = {
-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, -0.4, 0.3, -0.8,
-0.4, 0.3, -0.5, -0.4, -0.6, 0.3, -0.4, -0.6, -0.5, -0.5};
const std::vector<float> input_to_cell_weights = {
-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6,
0.6, -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8, 0.6};
const std::vector<float> input_to_output_weights = {
-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2,
0.6, -0.2, 0.4, -0.7, -0.3, -0.5, 0.1, 0.5, -0.6, -0.4};
const std::vector<float> input_gate_bias = {0.03, 0.15, 0.22, 0.38};
const std::vector<float> forget_gate_bias = {0.1, -0.3, -0.2, 0.1};
const std::vector<float> cell_gate_bias = {-0.05, 0.72, 0.25, 0.08};
const std::vector<float> output_gate_bias = {0.05, -0.01, 0.2, 0.1};
const std::vector<float> recurrent_to_input_weights = {
-0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};
const std::vector<float> recurrent_to_cell_weights = {
-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
const std::vector<float> recurrent_to_forget_weights = {
-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2};
const std::vector<float> recurrent_to_output_weights = {
0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};
const std::vector<float> input_layer_norm_coefficients = {0.1, 0.2, 0.3, 0.5};
const std::vector<float> forget_layer_norm_coefficients = {0.2, 0.2, 0.4,
0.3};
const std::vector<float> cell_layer_norm_coefficients = {0.7, 0.2, 0.3, 0.8};
const std::vector<float> output_layer_norm_coefficients = {0.6, 0.2, 0.2,
0.5};
const std::vector<float> projection_weights = {
-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
// Input ranges.
const std::vector<std::pair<float, float>> ranges = {
{-1.0, 127.0 / 128}, // input tensor
{-1.0, 1.0}, // input_to_input_weight tensor
{-1.0, 1.0}, // input_to_forget_weight tensor
{-1.0, 1.0}, // input_to_cell_weight tensor
{-1.0, 1.0}, // input_to_output_weight tensor
{-1.0, 1.0}, // recurrent_to_input_weight tensor
{-1.0, 1.0}, // recurrent_to_forget_weight tensor
{-1.0, 1.0}, // recurrent_to_cell_weight tensor
{-1.0, 1.0}, // recurrent_to_output_weight tensor
{-1, 1}, // cell_to_input_weight tensor
{-1, 1}, // cell_to_forget_weight tensor
{-1, 1}, // cell_to_output_weight tensor
{-100, 100}, // input_gate_bias tensor
{-100, 100}, // forget_gate_bias tensor
{-100, 100}, // cell_gate_bias tensor
{-100, 100}, // output_gate_bias tensor
{-0.5, 0.5}, // projection_weight tensor
{-1, 1}, // projection_bias tensor
{-1.0, 32767.0 / 32768}, // output_state tensor
{-1, 1}, // cell_state tensor
{-1.00001, 1.0}, // input_layer_norm_coefficient tensor
{-1.00001, 1.0}, // forget_layer_norm_coefficient tensor
{-1.00001, 1.0}, // cell_layer_norm_coefficient tensor
{-1.00001, 1.0}, // output_layer_norm_coefficient tensor
// Output scale is the same as output_state scale and only output_state
// scale is used in the op, so this is only provided for clarity.
{-1.0, 32767.0 / 32768}, // output tensor.
};
// The scale and zero point of intermediate tensors.
std::vector<std::pair<float, int>> intermediates = {
{0.007059, 0}, {0.007812, 0}, {0.007059, 0}, {0.007812, 0}, {0.007, 0}};
// Create model.
UnidirectionalSequenceLSTMIntegerOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*time_major=*/true,
/*use_cifg=*/false, /*use_peephole=*/false,
/*use_projection_weights=*/true,
/*use_projection_bias=*/false,
/*use_layer_norm=*/true,
/*use_8x8_8_implementation=*/false, ranges, intermediates);
// Do allocate.
lstm.PerformAllocateAndDelegate();
// Set weights.
lstm.SetInputToInputWeights(input_to_input_weights);
lstm.SetInputToCellWeights(input_to_cell_weights);
lstm.SetInputToForgetWeights(input_to_forget_weights);
lstm.SetInputToOutputWeights(input_to_output_weights);
lstm.SetInputGateBias(input_gate_bias);
lstm.SetCellBias(cell_gate_bias);
lstm.SetForgetGateBias(forget_gate_bias);
lstm.SetOutputGateBias(output_gate_bias);
lstm.SetRecurrentToInputWeights(recurrent_to_input_weights);
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights);
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights);
lstm.SetProjectionWeights(projection_weights);
lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients);
lstm.SetForgetLayerNormCoefficients(forget_layer_norm_coefficients);
lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients);
lstm.SetOutputLayerNormCoefficients(output_layer_norm_coefficients);
// Model inputs. sequence -batch - input
const std::vector<float> lstm_input = {
0.7, 0.8, 0.1, 0.2, 0.3, //
0.8, 0.1, 0.2, 0.4, 0.5, //
0.2, 0.7, 0.7, 0.1, 0.7, //
0.3, 0.2, 0.9, 0.8, 0.1, //
0.7, 0.8, 0.1, 0.2, 0.3, //
0.3, 0.2, 0.9, 0.8, 0.1, //
};
// Expected outputs, n_batch * sequence_length * n_output
const std::vector<int8_t> expected_output = {
127, 127, -108, -67, 127, 127, -128, 127, 127,
-128, 127, 127, 127, 127, 127, -128, 127, 127,
};
// Invoke and verify the result.
lstm.SetInput(lstm_input);
lstm.Invoke();
EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(expected_output));
}
TEST(IntegerUnidirectionalSequenceLstmOpTest,
NoCifg_Peephole_Projection_LayerNorm) {
// Hyper parameters.
const int n_batch = 2;
const int n_input = 5;
const int n_cell = 4;
const int n_output = 3;
const int sequence_length = 3;
// Model related weights.
const std::vector<float> input_to_input_weights = {
0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, 0.3, -0.4, 0.5,
-0.8, 0.7, -0.6, 0.5, -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};
const std::vector<float> input_to_forget_weights = {
-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, -0.4, 0.3, -0.8,
-0.4, 0.3, -0.5, -0.4, -0.6, 0.3, -0.4, -0.6, -0.5, -0.5};
const std::vector<float> input_to_cell_weights = {
-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6,
0.6, -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8, 0.6};
const std::vector<float> input_to_output_weights = {
-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2,
0.6, -0.2, 0.4, -0.7, -0.3, -0.5, 0.1, 0.5, -0.6, -0.4};
const std::vector<float> input_gate_bias = {0.03, 0.15, 0.22, 0.38};
const std::vector<float> forget_gate_bias = {0.1, -0.3, -0.2, 0.1};
const std::vector<float> cell_gate_bias = {-0.05, 0.72, 0.25, 0.08};
const std::vector<float> output_gate_bias = {0.05, -0.01, 0.2, 0.1};
const std::vector<float> recurrent_to_input_weights = {
-0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};
const std::vector<float> recurrent_to_cell_weights = {
-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
const std::vector<float> recurrent_to_forget_weights = {
-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2};
const std::vector<float> recurrent_to_output_weights = {
0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};
const std::vector<float> cell_to_input_weights = {0.3, -0.1, 0.1, -0.2};
const std::vector<float> cell_to_forget_weights = {0.2, -0.1, 0.1, -0.2};
const std::vector<float> cell_to_output_weights = {0.3, -0.1, 0.1, -0.3};
const std::vector<float> input_layer_norm_coefficients = {0.1, 0.2, 0.3, 0.5};
const std::vector<float> forget_layer_norm_coefficients = {0.2, 0.2, 0.4,
0.3};
const std::vector<float> cell_layer_norm_coefficients = {0.7, 0.2, 0.3, 0.8};
const std::vector<float> output_layer_norm_coefficients = {0.6, 0.2, 0.2,
0.5};
const std::vector<float> projection_weights = {
-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
// Input ranges.
const std::vector<std::pair<float, float>> ranges = {
{-1.0, 127.0 / 128}, // input tensor
{-1.0, 1.0}, // input_to_input_weight tensor
{-1.0, 1.0}, // input_to_forget_weight tensor
{-1.0, 1.0}, // input_to_cell_weight tensor
{-1.0, 1.0}, // input_to_output_weight tensor
{-1.0, 1.0}, // recurrent_to_input_weight tensor
{-0.9, 0.9}, // recurrent_to_forget_weight tensor
{-1.0, 1.0}, // recurrent_to_cell_weight tensor
{-1.0, 1.0}, // recurrent_to_output_weight tensor
{-0.3, 0.3}, // cell_to_input_weight tensor
{-0.3, 0.3}, // cell_to_forget_weight tensor
{-0.3, 0.3}, // cell_to_output_weight tensor
{-100, 100}, // input_gate_bias tensor
{-100, 80}, // forget_gate_bias tensor
{-100, 100}, // cell_gate_bias tensor
{-100, 100}, // output_gate_bias tensor
{-0.5, 0.5}, // projection_weight tensor
{-1, 1}, // projection_bias tensor
{-1.0, 32767.0 / 32768}, // output_state tensor
{-1, 1}, // cell_state tensor
{-0.5, 0.5}, // input_layer_norm_coefficient tensor
{-0.5, 0.5}, // forget_layer_norm_coefficient tensor
{-1.0, 1.0}, // cell_layer_norm_coefficient tensor
{-1.0, 1.0}, // output_layer_norm_coefficient tensor
// Output scale is the same as output_state scale and only output_state
// scale is used in the op, so this is only provided for clarity.
{-1.0, 32767.0 / 32768}, // output tensor.
};
// The scale and zero point of intermediate tensors.
std::vector<std::pair<float, int>> intermediates = {
{0.007059, 0}, {0.007812, 0}, {0.007059, 0}, {0.007812, 0}, {0.007, 0}};
// Create model.
UnidirectionalSequenceLSTMIntegerOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*time_major=*/true,
/*use_cifg=*/false, /*use_peephole=*/true,
/*use_projection_weights=*/true,
/*use_projection_bias=*/false,
/*use_layer_norm=*/true,
/*use_8x8_8_implementation=*/false, ranges, intermediates);
// Do allocate.
lstm.PerformAllocateAndDelegate();
// Set weights.
lstm.SetInputToInputWeights(input_to_input_weights);
lstm.SetInputToCellWeights(input_to_cell_weights);
lstm.SetInputToForgetWeights(input_to_forget_weights);
lstm.SetInputToOutputWeights(input_to_output_weights);
lstm.SetInputGateBias(input_gate_bias);
lstm.SetCellBias(cell_gate_bias);
lstm.SetForgetGateBias(forget_gate_bias);
lstm.SetOutputGateBias(output_gate_bias);
lstm.SetRecurrentToInputWeights(recurrent_to_input_weights);
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights);
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights);
lstm.SetCellToInputWeights(cell_to_input_weights);
lstm.SetCellToForgetWeights(cell_to_forget_weights);
lstm.SetCellToOutputWeights(cell_to_output_weights);
lstm.SetProjectionWeights(projection_weights);
lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients);
lstm.SetForgetLayerNormCoefficients(forget_layer_norm_coefficients);
lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients);
lstm.SetOutputLayerNormCoefficients(output_layer_norm_coefficients);
// Model inputs. sequence -batch - input
const std::vector<float> lstm_input = {
0.7, 0.8, 0.1, 0.2, 0.3, //
0.8, 0.1, 0.2, 0.4, 0.5, //
0.2, 0.7, 0.7, 0.1, 0.7, //
0.3, 0.2, 0.9, 0.8, 0.1, //
0.7, 0.8, 0.1, 0.2, 0.3, //
0.3, 0.2, 0.9, 0.8, 0.1, //
};
// Expected outputs, n_batch * sequence_length * n_output
const std::vector<int8_t> expected_output = {
127, 127, -16, -21, 127, 127, 23, 127, 127,
-128, 127, 127, 127, 127, 127, -128, 127, 127,
};
// Invoke and verify the result.
lstm.SetInput(lstm_input);
lstm.Invoke();
EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(expected_output));
}
#define QUANTIZE_PARAMETER_TEST(test) \
INSTANTIATE_TEST_SUITE_P(test, test, ::testing::ValuesIn({false, true}));

View File

@ -825,6 +825,9 @@ TfLiteStatus QuantizeIntemediateTensors(ModelT* model,
if (input.second.number_of_bits == 8 &&
input.second.symmetric == false) {
TensorT* tensor = subgraph->tensors[index_global].get();
if (tensor->quantization == nullptr) {
continue;
}
if (utils::HasMinMax(tensor)) {
utils::QuantizeActivation(tensor, activations_type,
error_reporter);