Make the names of "activation", "activation_state", "output_state" variables consistent: Use "output_state", used by the float implementation (LstmStepFloat).
Remove a few variables in the implementation and tests that represented the same value (mostly the indices of these tensors). Also rename "input cell state" to just "cell state". PiperOrigin-RevId: 316908413 Change-Id: Icb64ecd31c90f45ef21cf7d48849fb2ec0975d3a
This commit is contained in:
parent
c870b9f920
commit
406d9b5521
tensorflow/lite
kernels
lstm.cc lstm_eval.cc lstm_eval.h lstm_shared.h lstm_test.cc unidirectional_sequence_lstm.cc unidirectional_sequence_lstm_test.cc
tools/optimize/calibration/builtin_logging_ops
@ -68,19 +68,19 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
|
||||
const float cell_clip = params->cell_clip;
|
||||
const float proj_clip = params->proj_clip;
|
||||
|
||||
const TfLiteTensor* cell_tensor =
|
||||
GetVariableInput(context, node, kInputCellStateTensor);
|
||||
TF_LITE_ENSURE(context, cell_tensor != nullptr);
|
||||
const TfLiteTensor* cell_state =
|
||||
GetVariableInput(context, node, kCellStateTensor);
|
||||
TF_LITE_ENSURE(context, cell_state != nullptr);
|
||||
const TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
auto* cell_params =
|
||||
static_cast<TfLiteAffineQuantization*>(cell_tensor->quantization.params);
|
||||
auto* cell_state_params =
|
||||
static_cast<TfLiteAffineQuantization*>(cell_state->quantization.params);
|
||||
auto* proj_params = static_cast<TfLiteAffineQuantization*>(
|
||||
output_tensor->quantization.params);
|
||||
if (cell_clip > 0.0) {
|
||||
integer_lstm_param->quantized_cell_clip = static_cast<int32_t>(
|
||||
std::min(std::max(cell_clip / cell_params->scale->data[0], -32768.0f),
|
||||
32767.0f));
|
||||
integer_lstm_param->quantized_cell_clip = static_cast<int32_t>(std::min(
|
||||
std::max(cell_clip / cell_state_params->scale->data[0], -32768.0f),
|
||||
32767.0f));
|
||||
} else {
|
||||
integer_lstm_param->quantized_cell_clip = 0;
|
||||
}
|
||||
@ -134,9 +134,9 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
|
||||
const TfLiteTensor* projection_weights =
|
||||
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
|
||||
|
||||
TfLiteTensor* activation_state =
|
||||
GetVariableInput(context, node, kInputActivationStateTensor);
|
||||
TF_LITE_ENSURE(context, activation_state != nullptr);
|
||||
TfLiteTensor* output_state =
|
||||
GetVariableInput(context, node, kOutputStateTensor);
|
||||
TF_LITE_ENSURE(context, output_state != nullptr);
|
||||
|
||||
// Since we have already checked that weights are all there or none, we can
|
||||
// check the existence of only one to get the condition.
|
||||
@ -187,7 +187,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
|
||||
float layer_norm_forget_scale = default_scale;
|
||||
float layer_norm_cell_scale = default_scale;
|
||||
float layer_norm_output_scale = default_scale;
|
||||
float activation_scale = default_scale;
|
||||
float output_state_scale = default_scale;
|
||||
int cell_scale = 1;
|
||||
|
||||
// Effective scales.
|
||||
@ -231,7 +231,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
|
||||
if (use_projection) {
|
||||
proj_weight_scale = projection_weights->params.scale;
|
||||
}
|
||||
activation_scale = activation_state->params.scale;
|
||||
output_state_scale = output_state->params.scale;
|
||||
|
||||
input_to_forget_weight_scale = input_to_forget_weights->params.scale;
|
||||
input_to_cell_weight_scale = input_to_cell_weights->params.scale;
|
||||
@ -240,12 +240,8 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
|
||||
recurrent_to_cell_weight_scale = recurrent_to_cell_weights->params.scale;
|
||||
recurrent_to_output_weight_scale = recurrent_to_output_weights->params.scale;
|
||||
|
||||
// Get cell state.
|
||||
TfLiteTensor* cell_state =
|
||||
GetVariableInput(context, node, kInputCellStateTensor);
|
||||
TF_LITE_ENSURE(context, cell_state != nullptr);
|
||||
// Check cell state (already used above)
|
||||
TF_LITE_ENSURE(context, CheckedLog2(cell_state->params.scale, &cell_scale));
|
||||
|
||||
TF_LITE_ENSURE(context, cell_scale <= -9);
|
||||
integer_lstm_param->cell_scale = cell_scale;
|
||||
input_scale = input->params.scale;
|
||||
@ -255,31 +251,32 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
|
||||
effective_input_to_input_scale =
|
||||
input_to_input_weight_scale * input_scale / intermediate_scale[0];
|
||||
effective_recurrent_to_input_scale = recurrent_to_input_weight_scale *
|
||||
activation_scale /
|
||||
output_state_scale /
|
||||
intermediate_scale[0];
|
||||
}
|
||||
effective_input_to_forget_scale =
|
||||
input_to_forget_weight_scale * input_scale / intermediate_scale[1];
|
||||
effective_recurrent_to_forget_scale = recurrent_to_forget_weight_scale *
|
||||
activation_scale /
|
||||
output_state_scale /
|
||||
intermediate_scale[1];
|
||||
|
||||
effective_input_to_cell_scale =
|
||||
input_to_cell_weight_scale * input_scale / intermediate_scale[2];
|
||||
effective_recurrent_to_cell_scale =
|
||||
recurrent_to_cell_weight_scale * activation_scale / intermediate_scale[2];
|
||||
effective_recurrent_to_cell_scale = recurrent_to_cell_weight_scale *
|
||||
output_state_scale /
|
||||
intermediate_scale[2];
|
||||
|
||||
effective_input_to_output_scale =
|
||||
input_to_output_weight_scale * input_scale / intermediate_scale[3];
|
||||
effective_recurrent_to_output_scale = recurrent_to_output_weight_scale *
|
||||
activation_scale /
|
||||
output_state_scale /
|
||||
intermediate_scale[3];
|
||||
|
||||
effective_hidden_scale =
|
||||
std::pow(2, -15) / intermediate_scale[4] * std::pow(2, -15);
|
||||
|
||||
effective_proj_scale =
|
||||
proj_weight_scale * intermediate_scale[4] / activation_scale;
|
||||
proj_weight_scale * intermediate_scale[4] / output_state_scale;
|
||||
|
||||
if (use_peephole) {
|
||||
if (!use_cifg) {
|
||||
@ -419,11 +416,10 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
|
||||
const TfLiteTensor* projection_bias =
|
||||
GetOptionalInputTensor(context, node, kProjectionBiasTensor);
|
||||
|
||||
TfLiteTensor* activation_state =
|
||||
GetVariableInput(context, node, kInputActivationStateTensor);
|
||||
TF_LITE_ENSURE(context, activation_state != nullptr);
|
||||
TfLiteTensor* cell_state =
|
||||
GetVariableInput(context, node, kInputCellStateTensor);
|
||||
TfLiteTensor* output_state =
|
||||
GetVariableInput(context, node, kOutputStateTensor);
|
||||
TF_LITE_ENSURE(context, output_state != nullptr);
|
||||
TfLiteTensor* cell_state = GetVariableInput(context, node, kCellStateTensor);
|
||||
TF_LITE_ENSURE(context, cell_state != nullptr);
|
||||
|
||||
// Since we have already checked that weights are all there or none, we can
|
||||
@ -456,7 +452,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
|
||||
int32_t* output_bias_ptr = nullptr;
|
||||
int32_t* proj_bias_ptr = nullptr;
|
||||
int16_t* cell_ptr = nullptr;
|
||||
int8_t* activation_ptr = nullptr;
|
||||
int8_t* output_state_ptr = nullptr;
|
||||
|
||||
// Scales.
|
||||
const float default_scale = 1.0;
|
||||
@ -477,7 +473,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
|
||||
float layer_norm_forget_scale = default_scale;
|
||||
float layer_norm_cell_scale = default_scale;
|
||||
float layer_norm_output_scale = default_scale;
|
||||
float activation_scale = default_scale;
|
||||
float output_state_scale = default_scale;
|
||||
|
||||
// Effective scales.
|
||||
float effective_input_to_input_scale = default_scale;
|
||||
@ -495,7 +491,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
|
||||
|
||||
// Zero points
|
||||
int input_zp = 0;
|
||||
int activation_zp = 0;
|
||||
int output_state_zp = 0;
|
||||
|
||||
// Populate all the values.
|
||||
if (!use_cifg) {
|
||||
@ -537,7 +533,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
|
||||
proj_bias_ptr = projection_bias->data.i32;
|
||||
}
|
||||
}
|
||||
activation_scale = activation_state->params.scale;
|
||||
output_state_scale = output_state->params.scale;
|
||||
|
||||
input_to_forget_weight_ptr = input_to_forget_weights->data.int8;
|
||||
input_to_forget_weight_scale = input_to_forget_weights->params.scale;
|
||||
@ -554,11 +550,11 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
|
||||
forget_bias_ptr = forget_gate_bias->data.i32;
|
||||
cell_bias_ptr = cell_bias->data.i32;
|
||||
output_bias_ptr = output_gate_bias->data.i32;
|
||||
activation_ptr = activation_state->data.int8;
|
||||
output_state_ptr = output_state->data.int8;
|
||||
cell_ptr = cell_state->data.i16;
|
||||
input_scale = input->params.scale;
|
||||
input_zp = input->params.zero_point;
|
||||
activation_zp = activation_state->params.zero_point;
|
||||
output_state_zp = output_state->params.zero_point;
|
||||
|
||||
std::vector<float> intermediate_scale;
|
||||
for (int i = 0; i < 12; ++i) {
|
||||
@ -575,27 +571,28 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
|
||||
effective_input_to_input_scale =
|
||||
input_to_input_weight_scale * input_scale / intermediate_scale[1];
|
||||
effective_recurrent_to_input_scale = recurrent_to_input_weight_scale *
|
||||
activation_scale /
|
||||
output_state_scale /
|
||||
intermediate_scale[2];
|
||||
}
|
||||
effective_input_to_forget_scale =
|
||||
input_to_forget_weight_scale * input_scale / intermediate_scale[4];
|
||||
effective_recurrent_to_forget_scale = recurrent_to_forget_weight_scale *
|
||||
activation_scale /
|
||||
output_state_scale /
|
||||
intermediate_scale[5];
|
||||
|
||||
effective_input_to_cell_scale =
|
||||
input_to_cell_weight_scale * input_scale / intermediate_scale[7];
|
||||
effective_recurrent_to_cell_scale =
|
||||
recurrent_to_cell_weight_scale * activation_scale / intermediate_scale[8];
|
||||
effective_recurrent_to_cell_scale = recurrent_to_cell_weight_scale *
|
||||
output_state_scale /
|
||||
intermediate_scale[8];
|
||||
|
||||
effective_input_to_output_scale =
|
||||
input_to_output_weight_scale * input_scale / intermediate_scale[10];
|
||||
effective_recurrent_to_output_scale = recurrent_to_output_weight_scale *
|
||||
activation_scale /
|
||||
output_state_scale /
|
||||
intermediate_scale[11];
|
||||
effective_proj_scale =
|
||||
proj_weight_scale * std::pow(2, -15) / activation_scale;
|
||||
proj_weight_scale * std::pow(2, -15) / output_state_scale;
|
||||
|
||||
if (use_peephole) {
|
||||
if (!use_cifg) {
|
||||
@ -698,18 +695,16 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
|
||||
const float cell_clip = params->cell_clip;
|
||||
const float proj_clip = params->proj_clip;
|
||||
|
||||
const TfLiteTensor* cell_tensor =
|
||||
GetInput(context, node, kInputCellStateTensor);
|
||||
const TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
auto* cell_params = reinterpret_cast<TfLiteAffineQuantization*>(
|
||||
cell_tensor->quantization.params);
|
||||
auto* cell_state_params = reinterpret_cast<TfLiteAffineQuantization*>(
|
||||
cell_state->quantization.params);
|
||||
auto* proj_params = reinterpret_cast<TfLiteAffineQuantization*>(
|
||||
output_tensor->quantization.params);
|
||||
TF_LITE_ENSURE_EQ(context, cell_params->scale->data[0], 1.0 / 32768);
|
||||
TF_LITE_ENSURE_EQ(context, cell_state_params->scale->data[0], 1.0 / 32768);
|
||||
if (cell_clip > 0.0 && cell_clip < 1.0) {
|
||||
integer_lstm_param->quantized_cell_clip =
|
||||
static_cast<int>(cell_clip / cell_params->scale->data[0]);
|
||||
static_cast<int>(cell_clip / cell_state_params->scale->data[0]);
|
||||
} else {
|
||||
integer_lstm_param->quantized_cell_clip = 0;
|
||||
}
|
||||
@ -1026,12 +1021,12 @@ TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
|
||||
OpData* op_data,
|
||||
TfLiteNode* node) {
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* activation_state =
|
||||
GetVariableInput(context, node, kInputActivationStateTensor);
|
||||
TF_LITE_ENSURE(context, activation_state != nullptr);
|
||||
const TfLiteTensor* output_state =
|
||||
GetVariableInput(context, node, kOutputStateTensor);
|
||||
TF_LITE_ENSURE(context, output_state != nullptr);
|
||||
|
||||
const int32_t input_zero_point = -input->params.zero_point;
|
||||
const int32_t activation_zero_point = -activation_state->params.zero_point;
|
||||
const int32_t output_state_zero_point = -output_state->params.zero_point;
|
||||
|
||||
const TfLiteTensor* input_to_input_weights =
|
||||
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
|
||||
@ -1083,8 +1078,8 @@ TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
|
||||
TF_LITE_ENSURE_OK(
|
||||
context,
|
||||
PrecomputeZeroPointTimesWeightWithBias(
|
||||
context, activation_zero_point, recurrent_to_forget_weights, nullptr,
|
||||
&(integer_lstm_params->recurrent_to_forget_effective_bias)));
|
||||
context, output_state_zero_point, recurrent_to_forget_weights,
|
||||
nullptr, &(integer_lstm_params->recurrent_to_forget_effective_bias)));
|
||||
|
||||
// Modulation gate.
|
||||
const TfLiteTensor* cell_gate_bias =
|
||||
@ -1097,7 +1092,7 @@ TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
|
||||
TF_LITE_ENSURE_OK(
|
||||
context,
|
||||
PrecomputeZeroPointTimesWeightWithBias(
|
||||
context, activation_zero_point, recurrent_to_cell_weights, nullptr,
|
||||
context, output_state_zero_point, recurrent_to_cell_weights, nullptr,
|
||||
&(integer_lstm_params->recurrent_to_cell_effective_bias)));
|
||||
|
||||
// Output gate.
|
||||
@ -1112,8 +1107,8 @@ TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
|
||||
TF_LITE_ENSURE_OK(
|
||||
context,
|
||||
PrecomputeZeroPointTimesWeightWithBias(
|
||||
context, activation_zero_point, recurrent_to_output_weights, nullptr,
|
||||
&(integer_lstm_params->recurrent_to_output_effective_bias)));
|
||||
context, output_state_zero_point, recurrent_to_output_weights,
|
||||
nullptr, &(integer_lstm_params->recurrent_to_output_effective_bias)));
|
||||
|
||||
// Input gate. The calculation is only meaningful for non-cifg case.
|
||||
const TfLiteTensor* input_gate_bias =
|
||||
@ -1126,7 +1121,7 @@ TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
|
||||
TF_LITE_ENSURE_OK(
|
||||
context,
|
||||
PrecomputeZeroPointTimesWeightWithBias(
|
||||
context, activation_zero_point, recurrent_to_input_weights, nullptr,
|
||||
context, output_state_zero_point, recurrent_to_input_weights, nullptr,
|
||||
&(integer_lstm_params->recurrent_to_input_effective_bias)));
|
||||
|
||||
// Projection bias. The calculation is only meaningful for with projection.
|
||||
@ -1198,20 +1193,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
context, CheckInputTensorDimensions(context, node, n_input, n_output,
|
||||
n_cell, use_layer_norm, is_integer));
|
||||
|
||||
// Get the pointer to output, activation_state and cell_state tensors.
|
||||
// Get the pointer to output, output_state and cell_state tensors.
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
TfLiteTensor* activation_state =
|
||||
GetVariableInput(context, node, kInputActivationStateTensor);
|
||||
TF_LITE_ENSURE(context, activation_state != nullptr);
|
||||
TfLiteTensor* cell_state =
|
||||
GetVariableInput(context, node, kInputCellStateTensor);
|
||||
TfLiteTensor* output_state =
|
||||
GetVariableInput(context, node, kOutputStateTensor);
|
||||
TF_LITE_ENSURE(context, output_state != nullptr);
|
||||
TfLiteTensor* cell_state = GetVariableInput(context, node, kCellStateTensor);
|
||||
TF_LITE_ENSURE(context, cell_state != nullptr);
|
||||
|
||||
// Check the shape of input state tensors.
|
||||
// These tensor may be 1D or 2D. It's fine as long as the total size is
|
||||
// correct.
|
||||
TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output);
|
||||
TF_LITE_ENSURE_EQ(context, NumElements(output_state), n_batch * n_output);
|
||||
TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
|
||||
|
||||
// Resize the output tensors.
|
||||
@ -1275,7 +1269,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
if (is_hybrid_op) {
|
||||
op_data->compute_row_sums = true;
|
||||
// Allocate temporary tensors to store quantized values of input,
|
||||
// activation_state and cell_state tensors.
|
||||
// output_state and cell_state tensors.
|
||||
node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
|
||||
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
|
||||
input_quantized->type = input_to_output_weights->type;
|
||||
@ -1286,17 +1280,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
input_quantized_size));
|
||||
}
|
||||
node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
|
||||
TfLiteTensor* activation_state_quantized =
|
||||
TfLiteTensor* output_state_quantized =
|
||||
GetTemporary(context, node, /*index=*/2);
|
||||
activation_state_quantized->type = input_to_output_weights->type;
|
||||
activation_state_quantized->allocation_type = kTfLiteArenaRw;
|
||||
if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
|
||||
activation_state->dims)) {
|
||||
TfLiteIntArray* activation_state_quantized_size =
|
||||
TfLiteIntArrayCopy(activation_state->dims);
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, context->ResizeTensor(context, activation_state_quantized,
|
||||
activation_state_quantized_size));
|
||||
output_state_quantized->type = input_to_output_weights->type;
|
||||
output_state_quantized->allocation_type = kTfLiteArenaRw;
|
||||
if (!TfLiteIntArrayEqual(output_state_quantized->dims,
|
||||
output_state->dims)) {
|
||||
TfLiteIntArray* output_state_quantized_size =
|
||||
TfLiteIntArrayCopy(output_state->dims);
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
context->ResizeTensor(context, output_state_quantized,
|
||||
output_state_quantized_size));
|
||||
}
|
||||
node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
|
||||
TfLiteTensor* cell_state_quantized =
|
||||
@ -1540,11 +1534,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* projection_bias =
|
||||
GetOptionalInputTensor(context, node, kProjectionBiasTensor);
|
||||
|
||||
TfLiteTensor* activation_state =
|
||||
GetVariableInput(context, node, kInputActivationStateTensor);
|
||||
TF_LITE_ENSURE(context, activation_state != nullptr);
|
||||
TfLiteTensor* cell_state =
|
||||
GetVariableInput(context, node, kInputCellStateTensor);
|
||||
TfLiteTensor* output_state =
|
||||
GetVariableInput(context, node, kOutputStateTensor);
|
||||
TF_LITE_ENSURE(context, output_state != nullptr);
|
||||
TfLiteTensor* cell_state = GetVariableInput(context, node, kCellStateTensor);
|
||||
TF_LITE_ENSURE(context, cell_state != nullptr);
|
||||
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
@ -1569,7 +1562,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
|
||||
projection_bias, params, /*forward_sequence=*/true,
|
||||
/*time_major=*/true,
|
||||
/*output_offset=*/0, scratch_buffer, activation_state, cell_state,
|
||||
/*output_offset=*/0, scratch_buffer, output_state, cell_state,
|
||||
output);
|
||||
}
|
||||
case kTfLiteUInt8:
|
||||
@ -1580,7 +1573,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
|
||||
TfLiteTensor* input_quantized =
|
||||
GetTemporary(context, node, /*index=*/1);
|
||||
TfLiteTensor* activation_state_quantized =
|
||||
TfLiteTensor* output_state_quantized =
|
||||
GetTemporary(context, node, /*index=*/2);
|
||||
TfLiteTensor* cell_state_quantized =
|
||||
GetTemporary(context, node, /*index=*/3);
|
||||
@ -1614,8 +1607,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
/*time_major=*/true, /*output_offset=*/0, scratch_buffer,
|
||||
scaling_factors, prod_scaling_factors, recovered_cell_weights,
|
||||
input_quantized,
|
||||
/*aux_input_quantized=*/nullptr, activation_state_quantized,
|
||||
cell_state_quantized, activation_state, cell_state,
|
||||
/*aux_input_quantized=*/nullptr, output_state_quantized,
|
||||
cell_state_quantized, output_state, cell_state,
|
||||
output_scratch_buffer, output, zero_points, row_sums, row_sums_size,
|
||||
&op_data->compute_row_sums,
|
||||
CpuBackendContext::GetFromContext(context));
|
||||
@ -1638,9 +1631,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
forget_layer_norm_coefficients, cell_layer_norm_coefficients,
|
||||
output_layer_norm_coefficients, input_gate_bias, forget_gate_bias,
|
||||
cell_bias, output_gate_bias, projection_weights, projection_bias,
|
||||
params, &op_data->integer_lstm_param, activation_state,
|
||||
cell_state, output, scratch0, scratch1, scratch2, scratch3,
|
||||
scratch4, scratch5, CpuBackendContext::GetFromContext(context));
|
||||
params, &op_data->integer_lstm_param, output_state, cell_state,
|
||||
output, scratch0, scratch1, scratch2, scratch3, scratch4,
|
||||
scratch5, CpuBackendContext::GetFromContext(context));
|
||||
} else {
|
||||
TfLiteTensor* scratch0 = GetTemporary(context, node, /*index=*/0);
|
||||
TfLiteTensor* scratch1 = GetTemporary(context, node, /*index=*/1);
|
||||
@ -1660,7 +1653,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
forget_layer_norm_coefficients, cell_layer_norm_coefficients,
|
||||
output_layer_norm_coefficients, input_gate_bias, forget_gate_bias,
|
||||
cell_bias, output_gate_bias, projection_weights, projection_bias,
|
||||
params, activation_state, cell_state, output,
|
||||
params, output_state, cell_state, output,
|
||||
&op_data->integer_lstm_param, scratch0, scratch1, scratch2,
|
||||
scratch3, scratch4, scratch5, scratch6, scratch7);
|
||||
return kTfLiteOk;
|
||||
|
@ -900,7 +900,7 @@ inline void LstmStepHybrid(
|
||||
|
||||
// Fully quantized lstm kernel for 16 bit gate matmul output.
|
||||
//
|
||||
// Input activation of size n_batch * n_input:
|
||||
// Input tensor of size n_batch * n_input:
|
||||
// input_ptr
|
||||
//
|
||||
// LSTM weights:
|
||||
@ -972,7 +972,7 @@ inline void LstmStepHybrid(
|
||||
// cell_scale: the power of two scale for cell state.
|
||||
//
|
||||
// Zero points:
|
||||
// activation_zp: zero point of activation
|
||||
// output_state_zp: zero point of output state
|
||||
// hidden_zp: zero point for hidden state.
|
||||
//
|
||||
// Temporary pre-allocated storage for the calculation. Each is of size n_cell *
|
||||
@ -1048,8 +1048,8 @@ inline void LstmStepInteger(
|
||||
const int32_t* input_to_input_effective_bias,
|
||||
const int32_t* recurrent_to_input_effective_bias,
|
||||
const int32_t* projection_effective_bias, int32 n_batch, int32 n_cell,
|
||||
int32 n_input, int32 n_output, int8_t* activation_ptr,
|
||||
int32_t activation_zp, int16_t* cell_ptr, int8_t* output_ptr,
|
||||
int32 n_input, int32 n_output, int8_t* output_state_ptr,
|
||||
int32_t output_state_zp, int16_t* cell_ptr, int8_t* output_ptr,
|
||||
int16_t* scratch_0_ptr, int16_t* scratch_1_ptr, int16_t* scratch_2_ptr,
|
||||
int16_t* scratch_3_ptr, int8_t* scratch_4_ptr, int32_t* scratch_5_ptr,
|
||||
CpuBackendContext* context) {
|
||||
@ -1088,7 +1088,7 @@ inline void LstmStepInteger(
|
||||
n_batch, n_input, n_cell, 0, scratch_5_ptr, scratch_1_ptr, context);
|
||||
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
activation_ptr, recurrent_to_forget_effective_bias,
|
||||
output_state_ptr, recurrent_to_forget_effective_bias,
|
||||
recurrent_to_forget_weight_ptr, effective_recurrent_to_forget_scale_a,
|
||||
effective_recurrent_to_forget_scale_b, n_batch, n_output, n_cell, 0,
|
||||
scratch_5_ptr, scratch_1_ptr, context);
|
||||
@ -1115,7 +1115,7 @@ inline void LstmStepInteger(
|
||||
n_input, n_cell, 0, scratch_5_ptr, scratch_2_ptr, context);
|
||||
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
activation_ptr, recurrent_to_cell_effective_bias,
|
||||
output_state_ptr, recurrent_to_cell_effective_bias,
|
||||
recurrent_to_cell_weight_ptr, effective_recurrent_to_cell_scale_a,
|
||||
effective_recurrent_to_cell_scale_b, n_batch, n_output, n_cell, 0,
|
||||
scratch_5_ptr, scratch_2_ptr, context);
|
||||
@ -1139,7 +1139,7 @@ inline void LstmStepInteger(
|
||||
n_batch, n_input, n_cell, 0, scratch_5_ptr, scratch_0_ptr, context);
|
||||
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
activation_ptr, recurrent_to_input_effective_bias,
|
||||
output_state_ptr, recurrent_to_input_effective_bias,
|
||||
recurrent_to_input_weight_ptr, effective_recurrent_to_input_scale_a,
|
||||
effective_recurrent_to_input_scale_b, n_batch, n_output, n_cell, 0,
|
||||
scratch_5_ptr, scratch_0_ptr, context);
|
||||
@ -1180,7 +1180,7 @@ inline void LstmStepInteger(
|
||||
n_batch, n_input, n_cell, 0, scratch_5_ptr, scratch_3_ptr, context);
|
||||
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
activation_ptr, recurrent_to_output_effective_bias,
|
||||
output_state_ptr, recurrent_to_output_effective_bias,
|
||||
recurrent_to_output_weight_ptr, effective_recurrent_to_output_scale_a,
|
||||
effective_recurrent_to_output_scale_b, n_batch, n_output, n_cell, 0,
|
||||
scratch_5_ptr, scratch_3_ptr, context);
|
||||
@ -1213,7 +1213,7 @@ inline void LstmStepInteger(
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
scratch_4_ptr, projection_effective_bias, proj_weight_ptr,
|
||||
effective_proj_scale_a, effective_proj_scale_b, n_batch, n_cell,
|
||||
n_output, activation_zp, scratch_5_ptr, output_ptr, context);
|
||||
n_output, output_state_zp, scratch_5_ptr, output_ptr, context);
|
||||
if (quantized_proj_clip > 0) {
|
||||
tensor_utils::CwiseClipping(output_ptr, quantized_proj_clip, n_batch,
|
||||
n_output);
|
||||
@ -1221,12 +1221,12 @@ inline void LstmStepInteger(
|
||||
} else {
|
||||
std::copy_n(scratch_4_ptr, n_batch * n_output, output_ptr);
|
||||
}
|
||||
std::copy_n(output_ptr, n_batch * n_output, activation_ptr);
|
||||
std::copy_n(output_ptr, n_batch * n_output, output_state_ptr);
|
||||
}
|
||||
|
||||
// Fully quantized lstm kernel for 8 bit gate matmul output.
|
||||
//
|
||||
// Input activation of size n_batch * n_input:
|
||||
// Input tensor of size n_batch * n_input:
|
||||
// input_ptr
|
||||
//
|
||||
// LSTM weights:
|
||||
@ -1298,7 +1298,7 @@ inline void LstmStepInteger(
|
||||
// cell_scale: the power of two scale for cell state.
|
||||
//
|
||||
// Zero points:
|
||||
// activation_zp: zero point of activation
|
||||
// output_state_zp: zero point of output state.
|
||||
// hidden_zp: zero point for hidden state.
|
||||
//
|
||||
// Temporary pre-allocated storage for the calculation. Each is of size n_cell *
|
||||
@ -1367,8 +1367,8 @@ void LstmStepInteger(
|
||||
const int32_t* intermediate_scale_a, const int32_t* intermediate_scale_b,
|
||||
const int32_t* intermediate_zp, int32 quantized_cell_clip,
|
||||
int32 quantized_proj_clip, int32 n_batch, int32 n_cell, int32 n_input,
|
||||
int32 n_output, int32 output_batch_leading_dim, int8_t* activation_ptr,
|
||||
int32_t activation_zp, int16_t* cell_ptr, int8_t* output_ptr,
|
||||
int32 n_output, int32 output_batch_leading_dim, int8_t* output_state_ptr,
|
||||
int32_t output_state_zp, int16_t* cell_ptr, int8_t* output_ptr,
|
||||
int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
|
||||
int16_t* scratch4, int16_t* scratch5, int16_t* scratch6,
|
||||
int16_t* scratch7) {
|
||||
@ -1381,7 +1381,7 @@ void LstmStepInteger(
|
||||
n_batch, n_input, n_cell, scratch0, intermediate_zp[4]);
|
||||
|
||||
tensor_utils::MatrixBatchVectorMultiply(
|
||||
activation_ptr, activation_zp, recurrent_to_forget_weight_ptr,
|
||||
output_state_ptr, output_state_zp, recurrent_to_forget_weight_ptr,
|
||||
effective_recurrent_to_forget_scale_a,
|
||||
effective_recurrent_to_forget_scale_b, n_batch, n_output, n_cell,
|
||||
scratch1, intermediate_zp[5]);
|
||||
@ -1408,7 +1408,7 @@ void LstmStepInteger(
|
||||
n_input, n_cell, scratch0, intermediate_zp[7]);
|
||||
|
||||
tensor_utils::MatrixBatchVectorMultiply(
|
||||
activation_ptr, activation_zp, recurrent_to_cell_weight_ptr,
|
||||
output_state_ptr, output_state_zp, recurrent_to_cell_weight_ptr,
|
||||
effective_recurrent_to_cell_scale_a, effective_recurrent_to_cell_scale_b,
|
||||
n_batch, n_output, n_cell, scratch1, intermediate_zp[8]);
|
||||
|
||||
@ -1434,7 +1434,7 @@ void LstmStepInteger(
|
||||
n_batch, n_input, n_cell, scratch0, intermediate_zp[10]);
|
||||
|
||||
tensor_utils::MatrixBatchVectorMultiply(
|
||||
activation_ptr, activation_zp, recurrent_to_output_weight_ptr,
|
||||
output_state_ptr, output_state_zp, recurrent_to_output_weight_ptr,
|
||||
effective_recurrent_to_output_scale_a,
|
||||
effective_recurrent_to_output_scale_b, n_batch, n_output, n_cell,
|
||||
scratch1, intermediate_zp[11]);
|
||||
@ -1478,7 +1478,7 @@ void LstmStepInteger(
|
||||
// Projection.
|
||||
tensor_utils::MatrixBatchVectorMultiply(
|
||||
scratch3, proj_weight_ptr, effective_proj_scale_a, effective_proj_scale_b,
|
||||
proj_bias_ptr, n_batch, n_cell, n_output, activation_zp, output_ptr);
|
||||
proj_bias_ptr, n_batch, n_cell, n_output, output_state_zp, output_ptr);
|
||||
|
||||
// Projection clipping.
|
||||
if (quantized_proj_clip > 0) {
|
||||
@ -1486,8 +1486,8 @@ void LstmStepInteger(
|
||||
n_output);
|
||||
}
|
||||
|
||||
// Copy output to activation.
|
||||
std::copy_n(output_ptr, n_batch * n_output, activation_ptr);
|
||||
// Copy output to output state.
|
||||
std::copy_n(output_ptr, n_batch * n_output, output_state_ptr);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -1518,9 +1518,8 @@ TfLiteStatus EvalFloat(
|
||||
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
|
||||
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
||||
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
|
||||
int output_offset, TfLiteTensor* scratch_buffer,
|
||||
TfLiteTensor* activation_state, TfLiteTensor* cell_state,
|
||||
TfLiteTensor* output) {
|
||||
int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* output_state,
|
||||
TfLiteTensor* cell_state, TfLiteTensor* output) {
|
||||
TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
|
||||
int max_time, n_batch;
|
||||
if (input->dims->size == 3) {
|
||||
@ -1604,10 +1603,9 @@ TfLiteStatus EvalFloat(
|
||||
GetTensorData<float>(projection_weights),
|
||||
GetTensorData<float>(projection_bias), params, n_batch, n_cell,
|
||||
n_input, aux_input_size, n_output, output_batch_leading_dim,
|
||||
GetTensorData<float>(activation_state),
|
||||
GetTensorData<float>(cell_state), input_gate_scratch,
|
||||
forget_gate_scratch, cell_gate_scratch, output_gate_scratch,
|
||||
output_ptr);
|
||||
GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
|
||||
input_gate_scratch, forget_gate_scratch, cell_gate_scratch,
|
||||
output_gate_scratch, output_ptr);
|
||||
}
|
||||
} else {
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
@ -1628,9 +1626,9 @@ TfLiteStatus EvalFloat(
|
||||
float* output_ptr = GetTensorData<float>(output) +
|
||||
time_offset * output_step + output_offset;
|
||||
|
||||
// Offset the {activation,cell}_state pointers to the right batch.
|
||||
float* activation_state_ptr = GetTensorData<float>(activation_state) +
|
||||
b * output_batch_leading_dim;
|
||||
// Offset the {output,cell}_state pointers to the right batch.
|
||||
float* output_state_ptr =
|
||||
GetTensorData<float>(output_state) + b * output_batch_leading_dim;
|
||||
float* cell_state_ptr = GetTensorData<float>(cell_state) + b * n_cell;
|
||||
// Offset the scratch pointers to the right batch.
|
||||
float* input_gate_scratch_ptr =
|
||||
@ -1666,7 +1664,7 @@ TfLiteStatus EvalFloat(
|
||||
GetTensorData<float>(projection_weights),
|
||||
GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
|
||||
n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
|
||||
activation_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
|
||||
output_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
|
||||
forget_gate_scratch_ptr, cell_gate_scratch_ptr,
|
||||
output_gate_scratch_ptr, output_ptr);
|
||||
}
|
||||
@ -1939,10 +1937,10 @@ TfLiteStatus EvalInteger8x8_16(
|
||||
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
||||
const TfLiteLSTMParams* params,
|
||||
const lstm_eval::IntegerLstmParameter* integer_lstm_param,
|
||||
TfLiteTensor* activation_state, TfLiteTensor* cell_state,
|
||||
TfLiteTensor* output, TfLiteTensor* scratch0, TfLiteTensor* scratch1,
|
||||
TfLiteTensor* scratch2, TfLiteTensor* scratch3, TfLiteTensor* scratch4,
|
||||
TfLiteTensor* scratch5, CpuBackendContext* context) {
|
||||
TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output,
|
||||
TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
|
||||
TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
|
||||
CpuBackendContext* context) {
|
||||
TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
|
||||
const int n_input = input->dims->data[input->dims->size - 1];
|
||||
int max_time, n_batch;
|
||||
@ -1959,7 +1957,7 @@ TfLiteStatus EvalInteger8x8_16(
|
||||
const int n_output = recurrent_to_output_weights->dims->data[1];
|
||||
|
||||
// Activation zero point
|
||||
int activation_zp = activation_state->params.zero_point;
|
||||
int output_state_zp = output_state->params.zero_point;
|
||||
|
||||
// Get params for time/batch/sequence.
|
||||
const int output_batch_leading_dim =
|
||||
@ -2042,8 +2040,8 @@ TfLiteStatus EvalInteger8x8_16(
|
||||
integer_lstm_param->input_to_input_effective_bias.get(),
|
||||
integer_lstm_param->recurrent_to_input_effective_bias.get(),
|
||||
integer_lstm_param->projection_effective_bias.get(), n_batch, n_cell,
|
||||
n_input, n_output, GetTensorData<int8_t>(activation_state),
|
||||
activation_zp, GetTensorData<int16_t>(cell_state), output_ptr,
|
||||
n_input, n_output, GetTensorData<int8_t>(output_state), output_state_zp,
|
||||
GetTensorData<int16_t>(cell_state), output_ptr,
|
||||
GetTensorData<int16_t>(scratch0), GetTensorData<int16_t>(scratch1),
|
||||
GetTensorData<int16_t>(scratch2), GetTensorData<int16_t>(scratch3),
|
||||
GetTensorData<int8_t>(scratch4), GetTensorData<int32_t>(scratch5),
|
||||
@ -2072,7 +2070,7 @@ TfLiteStatus EvalInteger8x8_8(
|
||||
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
|
||||
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
|
||||
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
||||
const TfLiteLSTMParams* params, TfLiteTensor* activation_state,
|
||||
const TfLiteLSTMParams* params, TfLiteTensor* output_state,
|
||||
TfLiteTensor* cell_state, TfLiteTensor* output,
|
||||
const lstm_eval::IntegerLstmParameter* integer_lstm_param,
|
||||
TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
|
||||
@ -2131,11 +2129,11 @@ TfLiteStatus EvalInteger8x8_8(
|
||||
const int32_t* output_bias_ptr = GetTensorData<int32_t>(output_gate_bias);
|
||||
const int32_t* proj_bias_ptr = GetTensorData<int32_t>(projection_bias);
|
||||
int16_t* cell_ptr = GetTensorData<int16_t>(cell_state);
|
||||
int8_t* activation_ptr = GetTensorData<int8_t>(activation_state);
|
||||
int8_t* output_state_ptr = GetTensorData<int8_t>(output_state);
|
||||
int8_t* output_ptr = nullptr;
|
||||
|
||||
const int32 input_zp = input->params.zero_point;
|
||||
const int32 activation_zp = activation_state->params.zero_point;
|
||||
const int32 output_state_zp = output_state->params.zero_point;
|
||||
|
||||
// Get params for time/batch/sequence.
|
||||
const int output_batch_leading_dim =
|
||||
@ -2222,7 +2220,7 @@ TfLiteStatus EvalInteger8x8_8(
|
||||
integer_lstm_param->intermediate_zp,
|
||||
integer_lstm_param->quantized_cell_clip,
|
||||
integer_lstm_param->quantized_proj_clip, n_batch, n_cell, n_input,
|
||||
n_output, output_batch_leading_dim, activation_ptr, activation_zp,
|
||||
n_output, output_batch_leading_dim, output_state_ptr, output_state_zp,
|
||||
cell_ptr, output_ptr, GetTensorData<int8_t>(scratch0),
|
||||
GetTensorData<int8_t>(scratch1), GetTensorData<int16_t>(scratch2),
|
||||
GetTensorData<int16_t>(scratch3), GetTensorData<int16_t>(scratch4),
|
||||
|
@ -120,9 +120,8 @@ TfLiteStatus EvalFloat(
|
||||
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
|
||||
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
||||
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
|
||||
int output_offset, TfLiteTensor* scratch_buffer,
|
||||
TfLiteTensor* activation_state, TfLiteTensor* cell_state,
|
||||
TfLiteTensor* output);
|
||||
int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* output_state,
|
||||
TfLiteTensor* cell_state, TfLiteTensor* output);
|
||||
|
||||
TfLiteStatus EvalHybrid(
|
||||
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
|
||||
@ -179,10 +178,10 @@ TfLiteStatus EvalInteger8x8_16(
|
||||
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
||||
const TfLiteLSTMParams* params,
|
||||
const lstm_eval::IntegerLstmParameter* integer_lstm_param,
|
||||
TfLiteTensor* activation_state, TfLiteTensor* cell_state,
|
||||
TfLiteTensor* output, TfLiteTensor* scratch0, TfLiteTensor* scratch1,
|
||||
TfLiteTensor* scratch2, TfLiteTensor* scratch3, TfLiteTensor* scratch4,
|
||||
TfLiteTensor* scratch5, CpuBackendContext* context);
|
||||
TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output,
|
||||
TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
|
||||
TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
|
||||
CpuBackendContext* context);
|
||||
|
||||
TfLiteStatus EvalInteger8x8_8(
|
||||
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
|
||||
@ -203,7 +202,7 @@ TfLiteStatus EvalInteger8x8_8(
|
||||
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
|
||||
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
|
||||
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
||||
const TfLiteLSTMParams* params, TfLiteTensor* activation_state,
|
||||
const TfLiteLSTMParams* params, TfLiteTensor* output_state,
|
||||
TfLiteTensor* cell_state, TfLiteTensor* output,
|
||||
const lstm_eval::IntegerLstmParameter* integer_lstm_param,
|
||||
TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
|
||||
|
@ -57,8 +57,8 @@ constexpr int kProjectionBiasTensor = 17; // Optional
|
||||
|
||||
// These state tensors are defined as variable tensors, and will be modified by
|
||||
// this op.
|
||||
constexpr int kInputActivationStateTensor = 18;
|
||||
constexpr int kInputCellStateTensor = 19;
|
||||
constexpr int kOutputStateTensor = 18;
|
||||
constexpr int kCellStateTensor = 19;
|
||||
|
||||
// Layer norm coefficient tensors of size {n_cell}, representing a diagonal
|
||||
// matrix.
|
||||
|
@ -104,10 +104,10 @@ class LSTMOpModel : public SingleOpModel {
|
||||
projection_bias_ = AddNullInput();
|
||||
}
|
||||
|
||||
// Adding the 2 input state tensors.
|
||||
input_activation_state_ =
|
||||
// Adding the 2 state tensors.
|
||||
output_state_ =
|
||||
AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_output_}}, true);
|
||||
input_cell_state_ =
|
||||
cell_state_ =
|
||||
AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_cell_}}, true);
|
||||
|
||||
// Layer norm weights.
|
||||
@ -266,13 +266,11 @@ class LSTMOpModel : public SingleOpModel {
|
||||
|
||||
int projection_weights_;
|
||||
int projection_bias_;
|
||||
int input_activation_state_;
|
||||
int input_cell_state_;
|
||||
|
||||
int output_;
|
||||
int output_state_;
|
||||
int cell_state_;
|
||||
|
||||
int output_;
|
||||
|
||||
int n_batch_;
|
||||
int n_input_;
|
||||
int n_cell_;
|
||||
@ -553,7 +551,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest,
|
||||
{0, 0}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
|
||||
{0}, // input_layer_norm_coefficient tensor
|
||||
@ -1697,7 +1695,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||
{n_output, n_cell}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
|
||||
{n_cell}, // input_layer_norm_coefficient tensor
|
||||
@ -1768,7 +1766,7 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||
{n_output, n_cell}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
|
||||
{n_cell}, // input_layer_norm_coefficient tensor
|
||||
@ -1841,7 +1839,7 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
|
||||
{n_output, n_cell}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
|
||||
{n_cell}, // input_layer_norm_coefficient tensor
|
||||
@ -1955,7 +1953,7 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||
{n_output, n_cell}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
|
||||
{0}, // input_layer_norm_coefficient tensor
|
||||
@ -2026,7 +2024,7 @@ TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||
{n_output, n_cell}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
|
||||
{0}, // input_layer_norm_coefficient tensor
|
||||
@ -2098,7 +2096,7 @@ TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
|
||||
{n_output, n_cell}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
|
||||
{0}, // input_layer_norm_coefficient tensor
|
||||
@ -2216,13 +2214,13 @@ class LSTMIntegerOpModel : public SingleOpModel {
|
||||
projection_bias_ = AddNullInput();
|
||||
}
|
||||
|
||||
// Adding the 2 input state tensors.
|
||||
input_activation_state_ = AddInput({TensorType_INT16, input_shapes[18],
|
||||
ranges[18].first, ranges[18].second},
|
||||
true);
|
||||
input_cell_state_ = AddInput({TensorType_INT16, input_shapes[19],
|
||||
ranges[19].first, ranges[19].second},
|
||||
true);
|
||||
// Adding the 2 state tensors.
|
||||
output_state_ = AddInput({TensorType_INT16, input_shapes[18],
|
||||
ranges[18].first, ranges[18].second},
|
||||
true);
|
||||
cell_state_ = AddInput({TensorType_INT16, input_shapes[19],
|
||||
ranges[19].first, ranges[19].second},
|
||||
true);
|
||||
|
||||
// Layer norm weights.
|
||||
if (use_layer_norm) {
|
||||
@ -2386,8 +2384,6 @@ class LSTMIntegerOpModel : public SingleOpModel {
|
||||
|
||||
int projection_weights_;
|
||||
int projection_bias_;
|
||||
int input_activation_state_;
|
||||
int input_cell_state_;
|
||||
|
||||
int intermediates_[5];
|
||||
|
||||
@ -2483,7 +2479,7 @@ TEST(LSTMIntegerOpModel, NoCifgYesLayerNormNoYesProjectionNoPeephole) {
|
||||
{n_output, n_cell}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
|
||||
{n_cell}, // input_layer_norm_coefficient tensor
|
||||
@ -2517,14 +2513,14 @@ TEST(LSTMIntegerOpModel, NoCifgYesLayerNormNoYesProjectionNoPeephole) {
|
||||
{-0.5, 0.5}, // projection_weight tensor
|
||||
{-1, 1}, // projection_bias tensor
|
||||
|
||||
{-1.0, 32767.0 / 32768}, // activation_state tensor
|
||||
{-1.0, 32767.0 / 32768}, // output_state tensor
|
||||
{-1, 1}, // cell_state tensor
|
||||
|
||||
{-1.00001, 1.0}, // input_layer_norm_coefficient tensor
|
||||
{-1.00001, 1.0}, // forget_layer_norm_coefficient tensor
|
||||
{-1.00001, 1.0}, // cell_layer_norm_coefficient tensor
|
||||
{-1.00001, 1.0}, // output_layer_norm_coefficient tensor
|
||||
// Output scale is the same as input activation scale and only activation
|
||||
// Output scale is the same as output_state scale and only output_state
|
||||
// scale is used in the op, so this is only provided for clarity.
|
||||
{-1.0, 32767.0 / 32768}, // output tensor.
|
||||
};
|
||||
@ -2685,7 +2681,7 @@ TEST(LSTMIntegerOpModel, NoCifgYesLayerNormNoYesProjectionYesPeephole) {
|
||||
{n_output, n_cell}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
|
||||
{n_cell}, // input_layer_norm_coefficient tensor
|
||||
@ -2719,14 +2715,14 @@ TEST(LSTMIntegerOpModel, NoCifgYesLayerNormNoYesProjectionYesPeephole) {
|
||||
{-0.5, 0.5}, // projection_weight tensor
|
||||
{-1, 1}, // projection_bias tensor
|
||||
|
||||
{-1.0, 32767.0 / 32768}, // activation_state tensor
|
||||
{-1.0, 32767.0 / 32768}, // output_state tensor
|
||||
{-1, 1}, // cell_state tensor
|
||||
|
||||
{-0.5, 0.5}, // input_layer_norm_coefficient tensor
|
||||
{-0.5, 0.5}, // forget_layer_norm_coefficient tensor
|
||||
{-1.0, 1.0}, // cell_layer_norm_coefficient tensor
|
||||
{-1.0, 1.0}, // output_layer_norm_coefficient tensor
|
||||
// Output scale is the same as input activation scale and only activation
|
||||
// Output scale is the same as output_state scale and only output_state
|
||||
// scale is used in the op, so this is only provided for clarity.
|
||||
{-1.0, 32767.0 / 32768}, // output tensor.
|
||||
};
|
||||
@ -2892,13 +2888,13 @@ class LSTMIntegerOpModel8x8_8 : public SingleOpModel {
|
||||
projection_bias_ = AddNullInput();
|
||||
}
|
||||
|
||||
// Adding the 2 input state tensors.
|
||||
input_activation_state_ = AddInput({TensorType_INT16, input_shapes[18],
|
||||
ranges[18].first, ranges[18].second},
|
||||
true);
|
||||
input_cell_state_ = AddInput({TensorType_INT16, input_shapes[19],
|
||||
ranges[19].first, ranges[19].second},
|
||||
true);
|
||||
// Adding the 2 state tensors.
|
||||
output_state_ = AddInput({TensorType_INT16, input_shapes[18],
|
||||
ranges[18].first, ranges[18].second},
|
||||
true);
|
||||
cell_state_ = AddInput({TensorType_INT16, input_shapes[19],
|
||||
ranges[19].first, ranges[19].second},
|
||||
true);
|
||||
|
||||
// Layer norm weights.
|
||||
if (use_layer_norm) {
|
||||
@ -3062,8 +3058,6 @@ class LSTMIntegerOpModel8x8_8 : public SingleOpModel {
|
||||
|
||||
int projection_weights_;
|
||||
int projection_bias_;
|
||||
int input_activation_state_;
|
||||
int input_cell_state_;
|
||||
|
||||
int intermediates_[12];
|
||||
|
||||
@ -3160,7 +3154,7 @@ TEST(LSTMIntegerOpModel8x8_8, CifgYesLayerNormNoYesProjectionNoPeephole) {
|
||||
{n_output, n_cell}, // projection_weight tensor
|
||||
{n_output}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
|
||||
{0}, // input_layer_norm_coefficient tensor
|
||||
@ -3194,14 +3188,14 @@ TEST(LSTMIntegerOpModel8x8_8, CifgYesLayerNormNoYesProjectionNoPeephole) {
|
||||
{-0.5, 0.5}, // projection_weight tensor
|
||||
{-1, 1}, // projection_bias tensor
|
||||
|
||||
{-1.0, 32767.0 / 32768}, // activation_state tensor
|
||||
{-1.0, 32767.0 / 32768}, // output_state tensor
|
||||
{-1.0, 32767.0 / 32768}, // cell_state tensor
|
||||
|
||||
{-1.00001, 1.0}, // input_layer_norm_coefficient tensor
|
||||
{-1.00001, 1.0}, // forget_layer_norm_coefficient tensor
|
||||
{-1.00001, 1.0}, // cell_layer_norm_coefficient tensor
|
||||
{-1.00001, 1.0}, // output_layer_norm_coefficient tensor
|
||||
// Output scale is the same as input activation scale and only activation
|
||||
// Output scale is the same as output_state scale and only output_state
|
||||
// scale is used in the op, so this is only provided for clarity.
|
||||
{-1.0, 32767.0 / 32768}, // output tensor.
|
||||
};
|
||||
|
@ -317,20 +317,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
CheckInputTensorDimensions(context, node, n_input, n_output,
|
||||
n_cell, is_layer_norm_lstm));
|
||||
|
||||
// Get the pointer to output, activation_state and cell_state buffer tensors.
|
||||
// Get the pointer to output, output_state and cell_state buffer tensors.
|
||||
TfLiteTensor* output = GetOutput(context, node, lstm::full::kOutputTensor);
|
||||
|
||||
TfLiteTensor* activation_state =
|
||||
GetVariableInput(context, node, lstm::full::kInputActivationStateTensor);
|
||||
TF_LITE_ENSURE(context, activation_state != nullptr);
|
||||
TfLiteTensor* output_state =
|
||||
GetVariableInput(context, node, lstm::full::kOutputStateTensor);
|
||||
TF_LITE_ENSURE(context, output_state != nullptr);
|
||||
TfLiteTensor* cell_state =
|
||||
GetVariableInput(context, node, lstm::full::kInputCellStateTensor);
|
||||
GetVariableInput(context, node, lstm::full::kCellStateTensor);
|
||||
TF_LITE_ENSURE(context, cell_state != nullptr);
|
||||
|
||||
// Check the shape of input state tensors.
|
||||
// These tensor may be 1D or 2D. It's fine as long as the total size is
|
||||
// correct.
|
||||
TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output);
|
||||
TF_LITE_ENSURE_EQ(context, NumElements(output_state), n_batch * n_output);
|
||||
TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
|
||||
|
||||
// Resize the output tensors.
|
||||
@ -370,7 +370,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
if (IsHybridOp(input, input_to_output_weights)) {
|
||||
op_data->compute_row_sums = true;
|
||||
// Allocate temporary tensors to store quantized values of input,
|
||||
// activation_state and cell_state tensors.
|
||||
// output_state and cell_state tensors.
|
||||
node->temporaries->data[kInputQuantized] =
|
||||
scratch_tensor_index + kInputQuantized;
|
||||
TfLiteTensor* input_quantized =
|
||||
@ -384,17 +384,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
node->temporaries->data[kOutputStateQuantized] =
|
||||
scratch_tensor_index + kOutputStateQuantized;
|
||||
TfLiteTensor* activation_state_quantized =
|
||||
TfLiteTensor* output_state_quantized =
|
||||
GetTemporary(context, node, kOutputStateQuantized);
|
||||
activation_state_quantized->type = input_to_output_weights->type;
|
||||
activation_state_quantized->allocation_type = kTfLiteArenaRw;
|
||||
if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
|
||||
activation_state->dims)) {
|
||||
TfLiteIntArray* activation_state_quantized_size =
|
||||
TfLiteIntArrayCopy(activation_state->dims);
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, context->ResizeTensor(context, activation_state_quantized,
|
||||
activation_state_quantized_size));
|
||||
output_state_quantized->type = input_to_output_weights->type;
|
||||
output_state_quantized->allocation_type = kTfLiteArenaRw;
|
||||
if (!TfLiteIntArrayEqual(output_state_quantized->dims,
|
||||
output_state->dims)) {
|
||||
TfLiteIntArray* output_state_quantized_size =
|
||||
TfLiteIntArrayCopy(output_state->dims);
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
context->ResizeTensor(context, output_state_quantized,
|
||||
output_state_quantized_size));
|
||||
}
|
||||
node->temporaries->data[kCellStateQuantized] =
|
||||
scratch_tensor_index + kCellStateQuantized;
|
||||
@ -559,11 +559,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
// Index the scratch buffers pointers to the global scratch buffer.
|
||||
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
|
||||
|
||||
TfLiteTensor* activation_state =
|
||||
GetVariableInput(context, node, lstm::full::kInputActivationStateTensor);
|
||||
TF_LITE_ENSURE(context, activation_state != nullptr);
|
||||
TfLiteTensor* output_state =
|
||||
GetVariableInput(context, node, lstm::full::kOutputStateTensor);
|
||||
TF_LITE_ENSURE(context, output_state != nullptr);
|
||||
TfLiteTensor* cell_state =
|
||||
GetVariableInput(context, node, lstm::full::kInputCellStateTensor);
|
||||
GetVariableInput(context, node, lstm::full::kCellStateTensor);
|
||||
TF_LITE_ENSURE(context, cell_state != nullptr);
|
||||
|
||||
const TfLiteTensor* input_layer_norm_coefficients =
|
||||
@ -613,14 +613,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
|
||||
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
|
||||
projection_bias, &lstm_params, /*forward_sequence=*/true, time_major,
|
||||
/*output_offset=*/0, scratch_buffer, activation_state, cell_state,
|
||||
/*output_offset=*/0, scratch_buffer, output_state, cell_state,
|
||||
output);
|
||||
}
|
||||
case kTfLiteUInt8:
|
||||
case kTfLiteInt8: {
|
||||
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
|
||||
TfLiteTensor* activation_state_quantized =
|
||||
TfLiteTensor* output_state_quantized =
|
||||
GetTemporary(context, node, /*index=*/2);
|
||||
TfLiteTensor* cell_state_quantized =
|
||||
GetTemporary(context, node, /*index=*/3);
|
||||
@ -652,10 +652,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
projection_bias, &lstm_params, /*forward_sequence=*/true, time_major,
|
||||
/*output_offset=*/0, scratch_buffer, scaling_factors,
|
||||
prod_scaling_factors, recovered_cell_weights, input_quantized,
|
||||
/*aux_input_quantized=*/nullptr, activation_state_quantized,
|
||||
cell_state_quantized, activation_state, cell_state, accum_scratch,
|
||||
output, zero_points, row_sums, row_sums_size,
|
||||
&op_data->compute_row_sums,
|
||||
/*aux_input_quantized=*/nullptr, output_state_quantized,
|
||||
cell_state_quantized, output_state, cell_state, accum_scratch, output,
|
||||
zero_points, row_sums, row_sums_size, &op_data->compute_row_sums,
|
||||
CpuBackendContext::GetFromContext(context));
|
||||
}
|
||||
default:
|
||||
|
@ -100,13 +100,12 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
|
||||
projection_bias_ = AddNullInput();
|
||||
}
|
||||
|
||||
// Adding the 2 input state tensors.
|
||||
input_activation_state_ =
|
||||
// Adding the 2 state tensors.
|
||||
output_state_ =
|
||||
AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}},
|
||||
/*is_variable=*/true);
|
||||
input_cell_state_ =
|
||||
AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}},
|
||||
/*is_variable=*/true);
|
||||
cell_state_ = AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}},
|
||||
/*is_variable=*/true);
|
||||
|
||||
// Layer norm weights.
|
||||
if (is_layer_norm) {
|
||||
@ -256,8 +255,8 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
|
||||
int projection_weights_;
|
||||
int projection_bias_;
|
||||
|
||||
int input_activation_state_;
|
||||
int input_cell_state_;
|
||||
int output_state_;
|
||||
int cell_state_;
|
||||
|
||||
int input_layer_norm_coefficients_;
|
||||
int forget_layer_norm_coefficients_;
|
||||
@ -537,7 +536,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
||||
{0, 0}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
});
|
||||
|
||||
@ -599,7 +598,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
||||
{0, 0}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
});
|
||||
|
||||
@ -665,7 +664,7 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
||||
{0, 0}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
},
|
||||
TensorType_UINT8, GetParam());
|
||||
@ -728,7 +727,7 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
||||
{0, 0}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
},
|
||||
TensorType_INT8, GetParam());
|
||||
@ -840,7 +839,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
||||
{0, 0}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
});
|
||||
|
||||
@ -901,7 +900,7 @@ TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
||||
{0, 0}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
},
|
||||
TensorType_UINT8, GetParam());
|
||||
@ -964,7 +963,7 @@ TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
||||
{0, 0}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
},
|
||||
TensorType_INT8, GetParam());
|
||||
@ -1626,7 +1625,7 @@ TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
|
||||
{n_output, n_cell}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
});
|
||||
|
||||
@ -1695,7 +1694,7 @@ TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
|
||||
{n_output, n_cell}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
},
|
||||
TensorType_UINT8, GetParam());
|
||||
@ -1766,7 +1765,7 @@ TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
|
||||
{n_output, n_cell}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
},
|
||||
TensorType_INT8, GetParam());
|
||||
@ -2437,7 +2436,7 @@ TEST_F(NoCifgPeepholeProjectionAndBiasClippingUnidirectionalLstmTest,
|
||||
{n_output, n_cell}, // projection_weight tensor
|
||||
{n_output}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
});
|
||||
|
||||
@ -2643,7 +2642,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLayerNormUnidirectionalLstmTest,
|
||||
{0, 0}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
|
||||
{0}, // input_layer_norm_coefficient tensor
|
||||
@ -2714,7 +2713,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
|
||||
{0, 0}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_output}, // output_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
|
||||
{0}, // input_layer_norm_coefficient tensor
|
||||
|
@ -302,9 +302,8 @@ TfLiteStatus EvalFloat(
|
||||
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
|
||||
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
||||
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
|
||||
int output_offset, TfLiteTensor* scratch_buffer,
|
||||
TfLiteTensor* activation_state, TfLiteTensor* cell_state,
|
||||
TfLiteTensor* output, Logger* logger,
|
||||
int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* output_state,
|
||||
TfLiteTensor* cell_state, TfLiteTensor* output, Logger* logger,
|
||||
const std::vector<int>& intermediate_tensor_indexes,
|
||||
ErrorReporter* error_reporter) {
|
||||
TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
|
||||
@ -390,10 +389,10 @@ TfLiteStatus EvalFloat(
|
||||
GetTensorData<float>(projection_weights),
|
||||
GetTensorData<float>(projection_bias), params, n_batch, n_cell,
|
||||
n_input, aux_input_size, n_output, output_batch_leading_dim,
|
||||
GetTensorData<float>(activation_state),
|
||||
GetTensorData<float>(cell_state), input_gate_scratch,
|
||||
forget_gate_scratch, cell_gate_scratch, output_gate_scratch,
|
||||
output_ptr_time, logger, intermediate_tensor_indexes, error_reporter);
|
||||
GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
|
||||
input_gate_scratch, forget_gate_scratch, cell_gate_scratch,
|
||||
output_gate_scratch, output_ptr_time, logger,
|
||||
intermediate_tensor_indexes, error_reporter);
|
||||
}
|
||||
} else {
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
@ -414,9 +413,9 @@ TfLiteStatus EvalFloat(
|
||||
float* output_ptr = GetTensorData<float>(output) +
|
||||
time_offset * output_step + output_offset;
|
||||
|
||||
// Offset the {activation,cell}_state pointers to the right batch.
|
||||
float* activation_state_ptr = GetTensorData<float>(activation_state) +
|
||||
b * output_batch_leading_dim;
|
||||
// Offset the {output,cell}_state pointers to the right batch.
|
||||
float* output_state_ptr =
|
||||
GetTensorData<float>(output_state) + b * output_batch_leading_dim;
|
||||
float* cell_state_ptr = GetTensorData<float>(cell_state) + b * n_cell;
|
||||
// Offset the scratch pointers to the right batch.
|
||||
float* input_gate_scratch_ptr =
|
||||
@ -452,7 +451,7 @@ TfLiteStatus EvalFloat(
|
||||
GetTensorData<float>(projection_weights),
|
||||
GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
|
||||
n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
|
||||
activation_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
|
||||
output_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
|
||||
forget_gate_scratch_ptr, cell_gate_scratch_ptr,
|
||||
output_gate_scratch_ptr, output_ptr, logger,
|
||||
intermediate_tensor_indexes, error_reporter);
|
||||
@ -541,11 +540,11 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
|
||||
// Index the scratch buffers pointers to the global scratch buffer.
|
||||
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
|
||||
|
||||
TfLiteTensor* activation_state = GetVariableInput(
|
||||
context, node, ops::builtin::lstm::full::kInputActivationStateTensor);
|
||||
TF_LITE_ENSURE(context, activation_state != nullptr);
|
||||
TfLiteTensor* output_state = GetVariableInput(
|
||||
context, node, ops::builtin::lstm::full::kOutputStateTensor);
|
||||
TF_LITE_ENSURE(context, output_state != nullptr);
|
||||
TfLiteTensor* cell_state = GetVariableInput(
|
||||
context, node, ops::builtin::lstm::full::kInputCellStateTensor);
|
||||
context, node, ops::builtin::lstm::full::kCellStateTensor);
|
||||
TF_LITE_ENSURE(context, cell_state != nullptr);
|
||||
|
||||
TfLiteTensor* output =
|
||||
@ -574,8 +573,8 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
|
||||
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
|
||||
projection_bias, params, /*forward_sequence=*/true,
|
||||
/*time_major=*/true,
|
||||
/*output_offset=*/0, scratch_buffer, activation_state, cell_state,
|
||||
output, logger, intermediate_tensor_indexes, error_reporter);
|
||||
/*output_offset=*/0, scratch_buffer, output_state, cell_state, output,
|
||||
logger, intermediate_tensor_indexes, error_reporter);
|
||||
}
|
||||
case kTfLiteUInt8:
|
||||
case kTfLiteInt8:
|
||||
|
Loading…
Reference in New Issue
Block a user