Add support for batch-major input in the unidirectional LSTM Op.
PiperOrigin-RevId: 217406579
This commit is contained in:
		
							parent
							
								
									e6440a80c8
								
							
						
					
					
						commit
						a3f855aca2
					
				@ -187,10 +187,13 @@ typedef struct {
 | 
			
		||||
} TfLiteLSTMParams;
 | 
			
		||||
 | 
			
		||||
// Builtin parameters for the unidirectional sequence LSTM op. The op's
// Prepare/Eval obtain this struct by casting node->builtin_data.
typedef struct {
 | 
			
		||||
  // Parameters for the LSTM kernel.
 | 
			
		||||
  // Parameters needed for the underlying LSTM.
 | 
			
		||||
  TfLiteFusedActivation activation;
 | 
			
		||||
  float cell_clip;
 | 
			
		||||
  float proj_clip;
 | 
			
		||||
 | 
			
		||||
  // Layout of the sequence input. If true, the first input dimension is time
  // and the second is batch; if false, the first input dimension is batch.
  // NOTE(review): presumably the batch-major layout is [batch, time, input] --
  // confirm against the kernel's Eval.
 | 
			
		||||
  bool time_major;
 | 
			
		||||
} TfLiteUnidirectionalSequenceLSTMParams;
 | 
			
		||||
 | 
			
		||||
typedef struct {
 | 
			
		||||
 | 
			
		||||
@ -399,11 +399,11 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
 | 
			
		||||
            parse_activation(seq_lstm_params->fused_activation_function());
 | 
			
		||||
        params->cell_clip = seq_lstm_params->cell_clip();
 | 
			
		||||
        params->proj_clip = seq_lstm_params->proj_clip();
 | 
			
		||||
        params->time_major = seq_lstm_params->time_major();
 | 
			
		||||
      }
 | 
			
		||||
      *builtin_data = reinterpret_cast<void*>(params);
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: {
 | 
			
		||||
      auto params =
 | 
			
		||||
          allocator->AllocatePOD<TfLiteBidirectionalSequenceLSTMParams>();
 | 
			
		||||
 | 
			
		||||
@ -876,6 +876,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 | 
			
		||||
      params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0;
 | 
			
		||||
  const auto actual_bw_output = params->merge_outputs ? fw_output : bw_output;
 | 
			
		||||
 | 
			
		||||
  // TODO(mirkov): add batch_major support (http://b/117326122).
 | 
			
		||||
  switch (fw_input_to_output_weights->type) {
 | 
			
		||||
    case kTfLiteFloat32: {
 | 
			
		||||
      TfLiteStatus fw_pass_status = lstm_eval::EvalFloat(
 | 
			
		||||
@ -889,8 +890,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 | 
			
		||||
          fw_aux_input_to_output_weights, fw_input_gate_bias,
 | 
			
		||||
          fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
 | 
			
		||||
          fw_projection_weights, fw_projection_bias, &lstm_params,
 | 
			
		||||
          /*forward_sequence=*/true, /*output_offset=*/0, fw_scratch_buffer,
 | 
			
		||||
          fw_activation_state, fw_cell_state, fw_output);
 | 
			
		||||
          /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
 | 
			
		||||
          fw_scratch_buffer, fw_activation_state, fw_cell_state, fw_output);
 | 
			
		||||
      TF_LITE_ENSURE_OK(context, fw_pass_status);
 | 
			
		||||
 | 
			
		||||
      TfLiteStatus bw_pass_status = lstm_eval::EvalFloat(
 | 
			
		||||
@ -904,8 +905,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 | 
			
		||||
          bw_aux_input_to_output_weights, bw_input_gate_bias,
 | 
			
		||||
          bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
 | 
			
		||||
          bw_projection_weights, bw_projection_bias, &lstm_params,
 | 
			
		||||
          /*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer,
 | 
			
		||||
          bw_activation_state, bw_cell_state, actual_bw_output);
 | 
			
		||||
          /*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,
 | 
			
		||||
          bw_scratch_buffer, bw_activation_state, bw_cell_state,
 | 
			
		||||
          actual_bw_output);
 | 
			
		||||
      TF_LITE_ENSURE_OK(context, bw_pass_status);
 | 
			
		||||
      return kTfLiteOk;
 | 
			
		||||
    }
 | 
			
		||||
@ -942,11 +944,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 | 
			
		||||
          fw_aux_input_to_output_weights, fw_input_gate_bias,
 | 
			
		||||
          fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
 | 
			
		||||
          fw_projection_weights, fw_projection_bias, &lstm_params,
 | 
			
		||||
          /*forward_sequence=*/true, /*output_offset=*/0, fw_scratch_buffer,
 | 
			
		||||
          scaling_factors, prod_scaling_factors, recovered_cell_weights,
 | 
			
		||||
          input_quantized, aux_input_quantized, fw_activation_state_quantized,
 | 
			
		||||
          fw_cell_state_quantized, fw_activation_state, fw_cell_state,
 | 
			
		||||
          fw_output);
 | 
			
		||||
          /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
 | 
			
		||||
          fw_scratch_buffer, scaling_factors, prod_scaling_factors,
 | 
			
		||||
          recovered_cell_weights, input_quantized, aux_input_quantized,
 | 
			
		||||
          fw_activation_state_quantized, fw_cell_state_quantized,
 | 
			
		||||
          fw_activation_state, fw_cell_state, fw_output);
 | 
			
		||||
      TF_LITE_ENSURE_OK(context, fw_pass_status);
 | 
			
		||||
 | 
			
		||||
      TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
 | 
			
		||||
@ -960,11 +962,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 | 
			
		||||
          bw_aux_input_to_output_weights, bw_input_gate_bias,
 | 
			
		||||
          bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
 | 
			
		||||
          bw_projection_weights, bw_projection_bias, &lstm_params,
 | 
			
		||||
          /*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer,
 | 
			
		||||
          scaling_factors, prod_scaling_factors, recovered_cell_weights,
 | 
			
		||||
          input_quantized, aux_input_quantized, bw_activation_state_quantized,
 | 
			
		||||
          bw_cell_state_quantized, bw_activation_state, bw_cell_state,
 | 
			
		||||
          actual_bw_output);
 | 
			
		||||
          /*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,
 | 
			
		||||
          bw_scratch_buffer, scaling_factors, prod_scaling_factors,
 | 
			
		||||
          recovered_cell_weights, input_quantized, aux_input_quantized,
 | 
			
		||||
          bw_activation_state_quantized, bw_cell_state_quantized,
 | 
			
		||||
          bw_activation_state, bw_cell_state, actual_bw_output);
 | 
			
		||||
      TF_LITE_ENSURE_OK(context, bw_pass_status);
 | 
			
		||||
      return kTfLiteOk;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -497,6 +497,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 | 
			
		||||
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
 | 
			
		||||
          forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
 | 
			
		||||
          projection_bias, params, /*forward_sequence=*/true,
 | 
			
		||||
          /*time_major=*/true,
 | 
			
		||||
          /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
 | 
			
		||||
          output);
 | 
			
		||||
    }
 | 
			
		||||
@ -524,8 +525,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 | 
			
		||||
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
 | 
			
		||||
          forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
 | 
			
		||||
          projection_bias, params, /*forward_sequence=*/true,
 | 
			
		||||
          /*output_offset=*/0, scratch_buffer, scaling_factors,
 | 
			
		||||
          prod_scaling_factors, recovered_cell_weights, input_quantized,
 | 
			
		||||
          /*time_major=*/true, /*output_offset=*/0, scratch_buffer,
 | 
			
		||||
          scaling_factors, prod_scaling_factors, recovered_cell_weights,
 | 
			
		||||
          input_quantized,
 | 
			
		||||
          /*aux_input_quantized=*/nullptr, activation_state_quantized,
 | 
			
		||||
          cell_state_quantized, activation_state, cell_state, output);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -710,9 +710,10 @@ TfLiteStatus EvalFloat(
 | 
			
		||||
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
 | 
			
		||||
    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
 | 
			
		||||
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
 | 
			
		||||
    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
 | 
			
		||||
    TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
 | 
			
		||||
    TfLiteTensor* cell_state, TfLiteTensor* output) {
 | 
			
		||||
    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
 | 
			
		||||
    int output_offset, TfLiteTensor* scratch_buffer,
 | 
			
		||||
    TfLiteTensor* activation_state, TfLiteTensor* cell_state,
 | 
			
		||||
    TfLiteTensor* output) {
 | 
			
		||||
  TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
 | 
			
		||||
  const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
 | 
			
		||||
  const int n_batch = input->dims->data[input->dims->size - 2];
 | 
			
		||||
@ -777,36 +778,71 @@ TfLiteStatus EvalFloat(
 | 
			
		||||
    aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Loop through the sequence.
 | 
			
		||||
  const int output_batch_leading_dim =
 | 
			
		||||
      output->dims->data[output->dims->size - 1];
 | 
			
		||||
  const int input_step = n_batch * n_input;
 | 
			
		||||
  const int output_step = n_batch * output_batch_leading_dim;
 | 
			
		||||
  for (int t = 0; t < max_time; t++) {
 | 
			
		||||
    // If this is the forward_sequence, step forward, otherwise step backwards.
 | 
			
		||||
    const int t_rel = forward_sequence ? t : max_time - t - 1;
 | 
			
		||||
    const float* input_ptr = input->data.f + t_rel * input_step;
 | 
			
		||||
    if (aux_input) {
 | 
			
		||||
      aux_input_ptr = aux_input->data.f + t_rel * input_step;
 | 
			
		||||
    }
 | 
			
		||||
    float* output_ptr_time =
 | 
			
		||||
        output->data.f + t_rel * output_step + output_offset;
 | 
			
		||||
  if (time_major) {
 | 
			
		||||
    // Loop through the sequence.
 | 
			
		||||
    const int input_step = n_batch * n_input;
 | 
			
		||||
    const int output_step = n_batch * output_batch_leading_dim;
 | 
			
		||||
    for (int t = 0; t < max_time; t++) {
 | 
			
		||||
      // If this is the forward_sequence, step forward, otherwise step
 | 
			
		||||
      // backwards.
 | 
			
		||||
      const int t_rel = forward_sequence ? t : max_time - t - 1;
 | 
			
		||||
      const float* input_ptr = input->data.f + t_rel * input_step;
 | 
			
		||||
      if (aux_input) {
 | 
			
		||||
        aux_input_ptr = aux_input->data.f + t_rel * input_step;
 | 
			
		||||
      }
 | 
			
		||||
      float* output_ptr_time =
 | 
			
		||||
          output->data.f + t_rel * output_step + output_offset;
 | 
			
		||||
 | 
			
		||||
    LstmStepWithAuxInput(
 | 
			
		||||
        input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f,
 | 
			
		||||
        input_to_cell_weights->data.f, input_to_output_weights->data.f,
 | 
			
		||||
        aux_input_ptr, aux_input_to_input_weights_ptr,
 | 
			
		||||
        aux_input_to_forget_weights_ptr, aux_input_to_cell_weights_ptr,
 | 
			
		||||
        aux_input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
 | 
			
		||||
        recurrent_to_forget_weights->data.f, recurrent_to_cell_weights->data.f,
 | 
			
		||||
        recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
 | 
			
		||||
        cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
 | 
			
		||||
        input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
 | 
			
		||||
        output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
 | 
			
		||||
        params, n_batch, n_cell, n_input, aux_input_size, n_output,
 | 
			
		||||
        output_batch_leading_dim, activation_state->data.f, cell_state->data.f,
 | 
			
		||||
        input_gate_scratch, forget_gate_scratch, cell_scratch,
 | 
			
		||||
        output_gate_scratch, output_ptr_time);
 | 
			
		||||
      LstmStepWithAuxInput(
 | 
			
		||||
          input_ptr, input_to_input_weights_ptr,
 | 
			
		||||
          input_to_forget_weights->data.f, input_to_cell_weights->data.f,
 | 
			
		||||
          input_to_output_weights->data.f, aux_input_ptr,
 | 
			
		||||
          aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
 | 
			
		||||
          aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
 | 
			
		||||
          recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
 | 
			
		||||
          recurrent_to_cell_weights->data.f,
 | 
			
		||||
          recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
 | 
			
		||||
          cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
 | 
			
		||||
          input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
 | 
			
		||||
          output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
 | 
			
		||||
          params, n_batch, n_cell, n_input, aux_input_size, n_output,
 | 
			
		||||
          output_batch_leading_dim, activation_state->data.f,
 | 
			
		||||
          cell_state->data.f, input_gate_scratch, forget_gate_scratch,
 | 
			
		||||
          cell_scratch, output_gate_scratch, output_ptr_time);
 | 
			
		||||
    }
 | 
			
		||||
  } else {
 | 
			
		||||
    for (int b = 0; b < n_batch; b++) {
 | 
			
		||||
      const int input_step = n_input;
 | 
			
		||||
      const int output_step = output_batch_leading_dim;
 | 
			
		||||
      for (int t = 0; t < max_time; t++) {
 | 
			
		||||
        // If this is the forward_sequence, step forward, otherwise step
 | 
			
		||||
        // backwards.
 | 
			
		||||
        const int t_rel = forward_sequence ? t : max_time - t - 1;
 | 
			
		||||
        const float* input_ptr = input->data.f + t_rel * input_step;
 | 
			
		||||
        float* output_ptr_time =
 | 
			
		||||
            output->data.f + t_rel * output_step + output_offset;
 | 
			
		||||
 | 
			
		||||
        LstmStepWithAuxInput(
 | 
			
		||||
            input_ptr, input_to_input_weights_ptr,
 | 
			
		||||
            input_to_forget_weights->data.f, input_to_cell_weights->data.f,
 | 
			
		||||
            input_to_output_weights->data.f, aux_input_ptr,
 | 
			
		||||
            aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
 | 
			
		||||
            aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
 | 
			
		||||
            recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
 | 
			
		||||
            recurrent_to_cell_weights->data.f,
 | 
			
		||||
            recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
 | 
			
		||||
            cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
 | 
			
		||||
            input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
 | 
			
		||||
            output_gate_bias->data.f, projection_weights_ptr,
 | 
			
		||||
            projection_bias_ptr, params, /*n_batch=*/1, n_cell, n_input,
 | 
			
		||||
            aux_input_size, n_output, output_batch_leading_dim,
 | 
			
		||||
            activation_state->data.f, cell_state->data.f, input_gate_scratch,
 | 
			
		||||
            forget_gate_scratch, cell_scratch, output_gate_scratch,
 | 
			
		||||
            output_ptr_time);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return kTfLiteOk;
 | 
			
		||||
}
 | 
			
		||||
@ -830,13 +866,13 @@ TfLiteStatus EvalHybrid(
 | 
			
		||||
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
 | 
			
		||||
    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
 | 
			
		||||
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
 | 
			
		||||
    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
 | 
			
		||||
    TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
 | 
			
		||||
    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
 | 
			
		||||
    TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
 | 
			
		||||
    TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
 | 
			
		||||
    TfLiteTensor* output_state, TfLiteTensor* cell_state,
 | 
			
		||||
    TfLiteTensor* output) {
 | 
			
		||||
    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
 | 
			
		||||
    int output_offset, TfLiteTensor* scratch_buffer,
 | 
			
		||||
    TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
 | 
			
		||||
    TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
 | 
			
		||||
    TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
 | 
			
		||||
    TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
 | 
			
		||||
    TfLiteTensor* cell_state, TfLiteTensor* output) {
 | 
			
		||||
  TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
 | 
			
		||||
  const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
 | 
			
		||||
  const int n_batch = input->dims->data[input->dims->size - 2];
 | 
			
		||||
@ -990,45 +1026,90 @@ TfLiteStatus EvalHybrid(
 | 
			
		||||
        aux_input_to_output_weights->params.scale;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Feed the sequence into the LSTM step-by-step.
 | 
			
		||||
  const int output_batch_leading_dim =
 | 
			
		||||
      output->dims->data[output->dims->size - 1];
 | 
			
		||||
  const int input_step = n_batch * n_input;
 | 
			
		||||
  const int output_step = n_batch * output_batch_leading_dim;
 | 
			
		||||
  for (int t = 0; t < max_time; t++) {
 | 
			
		||||
    // If this is the forward_sequence, step forward, otherwise step backwards.
 | 
			
		||||
    const int t_rel = forward_sequence ? t : max_time - t - 1;
 | 
			
		||||
    const float* input_ptr = input->data.f + t_rel * input_step;
 | 
			
		||||
    if (aux_input) {
 | 
			
		||||
      aux_input_ptr = aux_input->data.f + t_rel * input_step;
 | 
			
		||||
    }
 | 
			
		||||
    float* output_ptr = output->data.f + t_rel * output_step + output_offset;
 | 
			
		||||
  if (time_major) {
 | 
			
		||||
    // Feed the sequence into the LSTM step-by-step.
 | 
			
		||||
    const int input_step = n_batch * n_input;
 | 
			
		||||
    const int output_step = n_batch * output_batch_leading_dim;
 | 
			
		||||
    for (int t = 0; t < max_time; t++) {
 | 
			
		||||
      // If this is the forward_sequence, step forward, otherwise step
 | 
			
		||||
      // backwards.
 | 
			
		||||
      const int t_rel = forward_sequence ? t : max_time - t - 1;
 | 
			
		||||
      const float* input_ptr = input->data.f + t_rel * input_step;
 | 
			
		||||
      if (aux_input) {
 | 
			
		||||
        aux_input_ptr = aux_input->data.f + t_rel * input_step;
 | 
			
		||||
      }
 | 
			
		||||
      float* output_ptr = output->data.f + t_rel * output_step + output_offset;
 | 
			
		||||
 | 
			
		||||
    LstmStepWithAuxInput(
 | 
			
		||||
        input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
 | 
			
		||||
        input_to_forget_weights_ptr, input_to_forget_weights_scale,
 | 
			
		||||
        input_to_cell_weights_ptr, input_to_cell_weights_scale,
 | 
			
		||||
        input_to_output_weights_ptr, input_to_output_weights_scale,
 | 
			
		||||
        aux_input_ptr, aux_input_to_input_weights_ptr,
 | 
			
		||||
        aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
 | 
			
		||||
        aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
 | 
			
		||||
        aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
 | 
			
		||||
        aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
 | 
			
		||||
        recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
 | 
			
		||||
        recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
 | 
			
		||||
        recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
 | 
			
		||||
        recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
 | 
			
		||||
        cell_to_input_weights_scale, cell_to_forget_weights_ptr,
 | 
			
		||||
        cell_to_forget_weights_scale, cell_to_output_weights_ptr,
 | 
			
		||||
        cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
 | 
			
		||||
        cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
 | 
			
		||||
        projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
 | 
			
		||||
        n_input, aux_input_size, n_output, output_batch_leading_dim,
 | 
			
		||||
        input_gate_scratch, forget_gate_scratch, cell_scratch,
 | 
			
		||||
        output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
 | 
			
		||||
        recovered_cell_weights_ptr, quantized_input_ptr,
 | 
			
		||||
        quantized_aux_input_ptr, quantized_output_state_ptr,
 | 
			
		||||
        quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr);
 | 
			
		||||
      LstmStepWithAuxInput(
 | 
			
		||||
          input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
 | 
			
		||||
          input_to_forget_weights_ptr, input_to_forget_weights_scale,
 | 
			
		||||
          input_to_cell_weights_ptr, input_to_cell_weights_scale,
 | 
			
		||||
          input_to_output_weights_ptr, input_to_output_weights_scale,
 | 
			
		||||
          aux_input_ptr, aux_input_to_input_weights_ptr,
 | 
			
		||||
          aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
 | 
			
		||||
          aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
 | 
			
		||||
          aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
 | 
			
		||||
          aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
 | 
			
		||||
          recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
 | 
			
		||||
          recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
 | 
			
		||||
          recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
 | 
			
		||||
          recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
 | 
			
		||||
          cell_to_input_weights_scale, cell_to_forget_weights_ptr,
 | 
			
		||||
          cell_to_forget_weights_scale, cell_to_output_weights_ptr,
 | 
			
		||||
          cell_to_output_weights_scale, input_gate_bias_ptr,
 | 
			
		||||
          forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
 | 
			
		||||
          projection_weights_ptr, projection_weights_scale, projection_bias_ptr,
 | 
			
		||||
          params, n_batch, n_cell, n_input, aux_input_size, n_output,
 | 
			
		||||
          output_batch_leading_dim, input_gate_scratch, forget_gate_scratch,
 | 
			
		||||
          cell_scratch, output_gate_scratch, scaling_factors_ptr,
 | 
			
		||||
          prod_scaling_factors_ptr, recovered_cell_weights_ptr,
 | 
			
		||||
          quantized_input_ptr, quantized_aux_input_ptr,
 | 
			
		||||
          quantized_output_state_ptr, quantized_cell_state_ptr,
 | 
			
		||||
          output_state_ptr, cell_state_ptr, output_ptr);
 | 
			
		||||
    }
 | 
			
		||||
  } else {
 | 
			
		||||
    for (int b = 0; b < n_batch; b++) {
 | 
			
		||||
      const int input_step = n_input;
 | 
			
		||||
      const int output_step = output_batch_leading_dim;
 | 
			
		||||
      for (int t = 0; t < max_time; t++) {
 | 
			
		||||
        // If this is the forward_sequence, step forward, otherwise step
 | 
			
		||||
        // backwards.
 | 
			
		||||
        const int t_rel = forward_sequence ? t : max_time - t - 1;
 | 
			
		||||
        const float* input_ptr = input->data.f + t_rel * input_step;
 | 
			
		||||
        float* output_ptr =
 | 
			
		||||
            output->data.f + t_rel * output_step + output_offset;
 | 
			
		||||
 | 
			
		||||
        LstmStepWithAuxInput(
 | 
			
		||||
            input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
 | 
			
		||||
            input_to_forget_weights_ptr, input_to_forget_weights_scale,
 | 
			
		||||
            input_to_cell_weights_ptr, input_to_cell_weights_scale,
 | 
			
		||||
            input_to_output_weights_ptr, input_to_output_weights_scale,
 | 
			
		||||
            aux_input_ptr, aux_input_to_input_weights_ptr,
 | 
			
		||||
            aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
 | 
			
		||||
            aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
 | 
			
		||||
            aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
 | 
			
		||||
            aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
 | 
			
		||||
            recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
 | 
			
		||||
            recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
 | 
			
		||||
            recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
 | 
			
		||||
            recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
 | 
			
		||||
            cell_to_input_weights_scale, cell_to_forget_weights_ptr,
 | 
			
		||||
            cell_to_forget_weights_scale, cell_to_output_weights_ptr,
 | 
			
		||||
            cell_to_output_weights_scale, input_gate_bias_ptr,
 | 
			
		||||
            forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
 | 
			
		||||
            projection_weights_ptr, projection_weights_scale,
 | 
			
		||||
            projection_bias_ptr, params, n_batch, n_cell, n_input,
 | 
			
		||||
            aux_input_size, n_output, output_batch_leading_dim,
 | 
			
		||||
            input_gate_scratch, forget_gate_scratch, cell_scratch,
 | 
			
		||||
            output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
 | 
			
		||||
            recovered_cell_weights_ptr, quantized_input_ptr,
 | 
			
		||||
            quantized_aux_input_ptr, quantized_output_state_ptr,
 | 
			
		||||
            quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
 | 
			
		||||
            output_ptr);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return kTfLiteOk;
 | 
			
		||||
 | 
			
		||||
@ -42,9 +42,10 @@ TfLiteStatus EvalFloat(
 | 
			
		||||
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
 | 
			
		||||
    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
 | 
			
		||||
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
 | 
			
		||||
    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
 | 
			
		||||
    TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
 | 
			
		||||
    TfLiteTensor* cell_state, TfLiteTensor* output);
 | 
			
		||||
    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
 | 
			
		||||
    int output_offset, TfLiteTensor* scratch_buffer,
 | 
			
		||||
    TfLiteTensor* activation_state, TfLiteTensor* cell_state,
 | 
			
		||||
    TfLiteTensor* output);
 | 
			
		||||
 | 
			
		||||
TfLiteStatus EvalHybrid(
 | 
			
		||||
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
 | 
			
		||||
@ -65,12 +66,13 @@ TfLiteStatus EvalHybrid(
 | 
			
		||||
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
 | 
			
		||||
    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
 | 
			
		||||
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
 | 
			
		||||
    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
 | 
			
		||||
    TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
 | 
			
		||||
    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
 | 
			
		||||
    TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
 | 
			
		||||
    TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
 | 
			
		||||
    TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output);
 | 
			
		||||
    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
 | 
			
		||||
    int output_offset, TfLiteTensor* scratch_buffer,
 | 
			
		||||
    TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
 | 
			
		||||
    TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
 | 
			
		||||
    TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
 | 
			
		||||
    TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
 | 
			
		||||
    TfLiteTensor* cell_state, TfLiteTensor* output);
 | 
			
		||||
 | 
			
		||||
}  // namespace lstm_eval
 | 
			
		||||
}  // namespace builtin
 | 
			
		||||
 | 
			
		||||
@ -260,8 +260,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 | 
			
		||||
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
 | 
			
		||||
  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
 | 
			
		||||
  TF_LITE_ENSURE(context, input->dims->size > 1);
 | 
			
		||||
  const int max_time = input->dims->data[0];
 | 
			
		||||
  const int n_batch = input->dims->data[1];
 | 
			
		||||
  const auto* params =
 | 
			
		||||
      reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
 | 
			
		||||
          node->builtin_data);
 | 
			
		||||
  const bool time_major = params->time_major;
 | 
			
		||||
  const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
 | 
			
		||||
  const int n_input = input->dims->data[2];
 | 
			
		||||
 | 
			
		||||
  const TfLiteTensor* input_to_output_weights =
 | 
			
		||||
@ -296,10 +299,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 | 
			
		||||
  TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
 | 
			
		||||
 | 
			
		||||
  // Resize the output tensors.
 | 
			
		||||
  TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
 | 
			
		||||
  output_size->data[0] = max_time;
 | 
			
		||||
  output_size->data[1] = n_batch;
 | 
			
		||||
  output_size->data[2] = n_output;
 | 
			
		||||
  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
 | 
			
		||||
  output_size->data[input->dims->size - 1] = n_output;
 | 
			
		||||
  TF_LITE_ENSURE_OK(context,
 | 
			
		||||
                    context->ResizeTensor(context, output, output_size));
 | 
			
		||||
 | 
			
		||||
@ -436,6 +437,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 | 
			
		||||
  const auto* params =
 | 
			
		||||
      reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
 | 
			
		||||
          node->builtin_data);
 | 
			
		||||
  const bool time_major = params->time_major;
 | 
			
		||||
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
 | 
			
		||||
 | 
			
		||||
  const TfLiteTensor* input_to_input_weights =
 | 
			
		||||
@ -506,7 +508,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 | 
			
		||||
          /*aux_input_to_cell_weights=*/nullptr,
 | 
			
		||||
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
 | 
			
		||||
          forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
 | 
			
		||||
          projection_bias, &lstm_params, /*forward_sequence=*/true,
 | 
			
		||||
          projection_bias, &lstm_params, /*forward_sequence=*/true, time_major,
 | 
			
		||||
          /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
 | 
			
		||||
          output);
 | 
			
		||||
    }
 | 
			
		||||
@ -533,7 +535,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 | 
			
		||||
          /*aux_input_to_cell_weights=*/nullptr,
 | 
			
		||||
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
 | 
			
		||||
          forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
 | 
			
		||||
          projection_bias, &lstm_params, /*forward_sequence=*/true,
 | 
			
		||||
          projection_bias, &lstm_params, /*forward_sequence=*/true, time_major,
 | 
			
		||||
          /*output_offset=*/0, scratch_buffer, scaling_factors,
 | 
			
		||||
          prod_scaling_factors, recovered_cell_weights, input_quantized,
 | 
			
		||||
          /*aux_input_quantized=*/nullptr, activation_state_quantized,
 | 
			
		||||
 | 
			
		||||
@ -32,7 +32,7 @@ using ::testing::ElementsAreArray;
 | 
			
		||||
class UnidirectionalLSTMOpModel : public SingleOpModel {
 | 
			
		||||
 public:
 | 
			
		||||
  UnidirectionalLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
 | 
			
		||||
                            int sequence_length, bool use_cifg,
 | 
			
		||||
                            int sequence_length, bool time_major, bool use_cifg,
 | 
			
		||||
                            bool use_peephole, bool use_projection_weights,
 | 
			
		||||
                            bool use_projection_bias, float cell_clip,
 | 
			
		||||
                            float proj_clip,
 | 
			
		||||
@ -110,12 +110,12 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
 | 
			
		||||
 | 
			
		||||
    output_ = AddOutput(TensorType_FLOAT32);
 | 
			
		||||
 | 
			
		||||
    SetBuiltinOp(
 | 
			
		||||
        BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
 | 
			
		||||
        BuiltinOptions_UnidirectionalSequenceLSTMOptions,
 | 
			
		||||
        CreateUnidirectionalSequenceLSTMOptions(
 | 
			
		||||
            builder_, ActivationFunctionType_TANH, cell_clip, proj_clip)
 | 
			
		||||
            .Union());
 | 
			
		||||
    SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
 | 
			
		||||
                 BuiltinOptions_UnidirectionalSequenceLSTMOptions,
 | 
			
		||||
                 CreateUnidirectionalSequenceLSTMOptions(
 | 
			
		||||
                     builder_, ActivationFunctionType_TANH, cell_clip,
 | 
			
		||||
                     proj_clip, time_major)
 | 
			
		||||
                     .Union());
 | 
			
		||||
    BuildInterpreter(input_shapes);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -241,12 +241,12 @@ class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
 | 
			
		||||
 public:
 | 
			
		||||
  HybridUnidirectionalLSTMOpModel(
 | 
			
		||||
      int n_batch, int n_input, int n_cell, int n_output, int sequence_length,
 | 
			
		||||
      bool use_cifg, bool use_peephole, bool use_projection_weights,
 | 
			
		||||
      bool use_projection_bias, float cell_clip, float proj_clip,
 | 
			
		||||
      const std::vector<std::vector<int>>& input_shapes)
 | 
			
		||||
      bool time_major, bool use_cifg, bool use_peephole,
 | 
			
		||||
      bool use_projection_weights, bool use_projection_bias, float cell_clip,
 | 
			
		||||
      float proj_clip, const std::vector<std::vector<int>>& input_shapes)
 | 
			
		||||
      : UnidirectionalLSTMOpModel(
 | 
			
		||||
            n_batch, n_input, n_cell, n_output, sequence_length, use_cifg,
 | 
			
		||||
            use_peephole, use_projection_weights, use_projection_bias,
 | 
			
		||||
            n_batch, n_input, n_cell, n_output, sequence_length, time_major,
 | 
			
		||||
            use_cifg, use_peephole, use_projection_weights, use_projection_bias,
 | 
			
		||||
            cell_clip, proj_clip, input_shapes, TensorType_UINT8) {}
 | 
			
		||||
 | 
			
		||||
  void SetInputToInputWeights(const std::vector<float>& f) {
 | 
			
		||||
@ -326,21 +326,32 @@ class BaseLstmTest : public ::testing::Test {
 | 
			
		||||
  // Compares output up to tolerance to the result of the lstm given the input.
 | 
			
		||||
  void VerifyGoldens(const std::vector<std::vector<float>>& input,
 | 
			
		||||
                     const std::vector<std::vector<float>>& output,
 | 
			
		||||
                     UnidirectionalLSTMOpModel* lstm, float tolerance = 1e-5) {
 | 
			
		||||
                     UnidirectionalLSTMOpModel* lstm, float tolerance = 1e-5,
 | 
			
		||||
                     bool time_major = true) {
 | 
			
		||||
    const int num_batches = input.size();
 | 
			
		||||
    EXPECT_GT(num_batches, 0);
 | 
			
		||||
    const int num_inputs = lstm->num_inputs();
 | 
			
		||||
    EXPECT_GT(num_inputs, 0);
 | 
			
		||||
    const int input_sequence_size = input[0].size() / num_inputs;
 | 
			
		||||
    EXPECT_GT(input_sequence_size, 0);
 | 
			
		||||
    // Feed the whole sequence as input.
 | 
			
		||||
    for (int i = 0; i < input_sequence_size; ++i) {
 | 
			
		||||
      for (int b = 0; b < num_batches; ++b) {
 | 
			
		||||
        const float* batch_start = input[b].data() + i * num_inputs;
 | 
			
		||||
        const float* batch_end = batch_start + num_inputs;
 | 
			
		||||
    if (time_major) {
 | 
			
		||||
      // Feed the whole sequence as input.
 | 
			
		||||
      for (int i = 0; i < input_sequence_size; ++i) {
 | 
			
		||||
        for (int b = 0; b < num_batches; ++b) {
 | 
			
		||||
          const float* batch_start = input[b].data() + i * num_inputs;
 | 
			
		||||
          const float* batch_end = batch_start + num_inputs;
 | 
			
		||||
 | 
			
		||||
        lstm->SetInput(((i * num_batches) + b) * lstm->num_inputs(),
 | 
			
		||||
                       batch_start, batch_end);
 | 
			
		||||
          lstm->SetInput(((i * num_batches) + b) * num_inputs, batch_start,
 | 
			
		||||
                         batch_end);
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    } else {
 | 
			
		||||
      for (int b = 0; b < num_batches; ++b) {
 | 
			
		||||
        const float* batch_start = input[b].data();
 | 
			
		||||
        const float* batch_end = batch_start + input_sequence_size * num_inputs;
 | 
			
		||||
 | 
			
		||||
        lstm->SetInput(b * input_sequence_size * num_inputs, batch_start,
 | 
			
		||||
                       batch_end);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -349,15 +360,25 @@ class BaseLstmTest : public ::testing::Test {
 | 
			
		||||
    const int num_outputs = lstm->num_outputs();
 | 
			
		||||
    EXPECT_GT(num_outputs, 0);
 | 
			
		||||
    std::vector<float> expected;
 | 
			
		||||
    for (int i = 0; i < input_sequence_size; ++i) {
 | 
			
		||||
      for (int b = 0; b < num_batches; ++b) {
 | 
			
		||||
        const float* golden_start_batch = output[b].data() + i * num_outputs;
 | 
			
		||||
        const float* golden_end_batch = golden_start_batch + num_outputs;
 | 
			
		||||
 | 
			
		||||
        expected.insert(expected.end(), golden_start_batch, golden_end_batch);
 | 
			
		||||
    if (time_major) {
 | 
			
		||||
      for (int i = 0; i < input_sequence_size; ++i) {
 | 
			
		||||
        for (int b = 0; b < num_batches; ++b) {
 | 
			
		||||
          const float* golden_start_batch = output[b].data() + i * num_outputs;
 | 
			
		||||
          const float* golden_end_batch = golden_start_batch + num_outputs;
 | 
			
		||||
 | 
			
		||||
          expected.insert(expected.end(), golden_start_batch, golden_end_batch);
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    } else {
 | 
			
		||||
      for (int b = 0; b < num_batches; ++b) {
 | 
			
		||||
        const float* golden_batch_start = output[b].data();
 | 
			
		||||
        const float* golden_batch_end =
 | 
			
		||||
            golden_batch_start + input_sequence_size * num_outputs;
 | 
			
		||||
 | 
			
		||||
        expected.insert(expected.end(), golden_batch_start, golden_batch_end);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    EXPECT_THAT(lstm->GetOutput(),
 | 
			
		||||
                ElementsAreArray(ArrayFloatNear(expected, tolerance)));
 | 
			
		||||
  }
 | 
			
		||||
@ -422,7 +443,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
 | 
			
		||||
 | 
			
		||||
  UnidirectionalLSTMOpModel lstm(
 | 
			
		||||
      n_batch, n_input, n_cell, n_output, sequence_length,
 | 
			
		||||
      /*use_cifg=*/false, /*use_peephole=*/false,
 | 
			
		||||
      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/false,
 | 
			
		||||
      /*use_projection_weights=*/false,
 | 
			
		||||
      /*use_projection_bias=*/false,
 | 
			
		||||
      /*cell_clip=*/0.0, /*proj_clip=*/0.0,
 | 
			
		||||
@ -473,6 +494,73 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
 | 
			
		||||
  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Runs the float LSTM with batch-major input, i.e. the op is configured with
// time_major == false and the input tensor is shaped
// {n_batch, sequence_length, n_input}.
//
// The fixture's inputs and golden outputs are stored per batch, which is
// exactly what VerifyGoldens copies in its batch-major branch, so no
// reshuffling of the data is needed here.
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
       LstmBlackBoxTestBatchMajor) {
  const int n_batch = 1;
  const int n_input = 2;
  // n_cell and n_output have the same size when there is no projection.
  const int n_cell = 4;
  const int n_output = 4;
  const int sequence_length = 3;

  // The model must be built batch-major as well: the op option and the input
  // shape have to agree with the layout VerifyGoldens feeds below, otherwise
  // the test only exercises the time-major path.
  UnidirectionalLSTMOpModel lstm(
      n_batch, n_input, n_cell, n_output, sequence_length,
      /*time_major=*/false, /*use_cifg=*/false, /*use_peephole=*/false,
      /*use_projection_weights=*/false,
      /*use_projection_bias=*/false,
      /*cell_clip=*/0.0, /*proj_clip=*/0.0,
      {
          {n_batch, sequence_length, n_input},  // input tensor (batch-major)

          {n_cell, n_input},  // input_to_input_weight tensor
          {n_cell, n_input},  // input_to_forget_weight tensor
          {n_cell, n_input},  // input_to_cell_weight tensor
          {n_cell, n_input},  // input_to_output_weight tensor

          {n_cell, n_output},  // recurrent_to_input_weight tensor
          {n_cell, n_output},  // recurrent_to_forget_weight tensor
          {n_cell, n_output},  // recurrent_to_cell_weight tensor
          {n_cell, n_output},  // recurrent_to_output_weight tensor

          {0},  // cell_to_input_weight tensor
          {0},  // cell_to_forget_weight tensor
          {0},  // cell_to_output_weight tensor

          {n_cell},  // input_gate_bias tensor
          {n_cell},  // forget_gate_bias tensor
          {n_cell},  // cell_bias tensor
          {n_cell},  // output_gate_bias tensor

          {0, 0},  // projection_weight tensor
          {0},     // projection_bias tensor

          {n_batch, n_output},  // activation_state tensor
          {n_batch, n_cell},    // cell_state tensor
      });

  lstm.SetInputToInputWeights(input_to_input_weights_);
  lstm.SetInputToCellWeights(input_to_cell_weights_);
  lstm.SetInputToForgetWeights(input_to_forget_weights_);
  lstm.SetInputToOutputWeights(input_to_output_weights_);

  lstm.SetInputGateBias(input_gate_bias_);
  lstm.SetCellBias(cell_gate_bias_);
  lstm.SetForgetGateBias(forget_gate_bias_);
  lstm.SetOutputGateBias(output_gate_bias_);

  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);

  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/1e-5,
                /*time_major=*/false);
}
 | 
			
		||||
 | 
			
		||||
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
 | 
			
		||||
  const int n_batch = 1;
 | 
			
		||||
  const int n_input = 2;
 | 
			
		||||
@ -483,7 +571,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
 | 
			
		||||
 | 
			
		||||
  HybridUnidirectionalLSTMOpModel lstm(
 | 
			
		||||
      n_batch, n_input, n_cell, n_output, sequence_length,
 | 
			
		||||
      /*use_cifg=*/false, /*use_peephole=*/false,
 | 
			
		||||
      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/false,
 | 
			
		||||
      /*use_projection_weights=*/false,
 | 
			
		||||
      /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
 | 
			
		||||
      {
 | 
			
		||||
@ -591,7 +679,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
 | 
			
		||||
 | 
			
		||||
  UnidirectionalLSTMOpModel lstm(
 | 
			
		||||
      n_batch, n_input, n_cell, n_output, sequence_length,
 | 
			
		||||
      /*use_cifg=*/true, /*use_peephole=*/true,
 | 
			
		||||
      /*time_major=*/true, /*use_cifg=*/true, /*use_peephole=*/true,
 | 
			
		||||
      /*use_projection_weights=*/false,
 | 
			
		||||
      /*use_projection_bias=*/false,
 | 
			
		||||
      /*cell_clip=*/0.0, /*proj_clip=*/0.0,
 | 
			
		||||
@ -652,7 +740,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
 | 
			
		||||
 | 
			
		||||
  HybridUnidirectionalLSTMOpModel lstm(
 | 
			
		||||
      n_batch, n_input, n_cell, n_output, sequence_length,
 | 
			
		||||
      /*use_cifg=*/true, /*use_peephole=*/true,
 | 
			
		||||
      /*time_major=*/true, /*use_cifg=*/true, /*use_peephole=*/true,
 | 
			
		||||
      /*use_projection_weights=*/false,
 | 
			
		||||
      /*use_projection_bias=*/false,
 | 
			
		||||
      /*cell_clip=*/0.0, /*proj_clip=*/0.0,
 | 
			
		||||
@ -1311,7 +1399,7 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
 | 
			
		||||
 | 
			
		||||
  UnidirectionalLSTMOpModel lstm(
 | 
			
		||||
      n_batch, n_input, n_cell, n_output, sequence_length,
 | 
			
		||||
      /*use_cifg=*/false, /*use_peephole=*/true,
 | 
			
		||||
      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true,
 | 
			
		||||
      /*use_projection_weights=*/true,
 | 
			
		||||
      /*use_projection_bias=*/false,
 | 
			
		||||
      /*cell_clip=*/0.0, /*proj_clip=*/0.0,
 | 
			
		||||
@ -1377,7 +1465,7 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
 | 
			
		||||
 | 
			
		||||
  HybridUnidirectionalLSTMOpModel lstm(
 | 
			
		||||
      n_batch, n_input, n_cell, n_output, sequence_length,
 | 
			
		||||
      /*use_cifg=*/false, /*use_peephole=*/true,
 | 
			
		||||
      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true,
 | 
			
		||||
      /*use_projection_weights=*/true,
 | 
			
		||||
      /*use_projection_bias=*/false,
 | 
			
		||||
      /*cell_clip=*/0.0, /*proj_clip=*/0.0,
 | 
			
		||||
 | 
			
		||||
@ -407,6 +407,9 @@ table UnidirectionalSequenceLSTMOptions {
 | 
			
		||||
  fused_activation_function:ActivationFunctionType;
 | 
			
		||||
  cell_clip: float; // Optional, 0.0 means no clipping
 | 
			
		||||
  proj_clip: float; // Optional, 0.0 means no clipping
 | 
			
		||||
 | 
			
		||||
  // If true then first dimension is sequence, otherwise batch.
 | 
			
		||||
  time_major:bool;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
table BidirectionalSequenceLSTMOptions {
 | 
			
		||||
 | 
			
		||||
@ -3534,10 +3534,12 @@ struct UnidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
 | 
			
		||||
  ActivationFunctionType fused_activation_function;
 | 
			
		||||
  float cell_clip;
 | 
			
		||||
  float proj_clip;
 | 
			
		||||
  bool time_major;
 | 
			
		||||
  UnidirectionalSequenceLSTMOptionsT()
 | 
			
		||||
      : fused_activation_function(ActivationFunctionType_NONE),
 | 
			
		||||
        cell_clip(0.0f),
 | 
			
		||||
        proj_clip(0.0f) {
 | 
			
		||||
        proj_clip(0.0f),
 | 
			
		||||
        time_major(false) {
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
@ -3546,7 +3548,8 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb
 | 
			
		||||
  enum {
 | 
			
		||||
    VT_FUSED_ACTIVATION_FUNCTION = 4,
 | 
			
		||||
    VT_CELL_CLIP = 6,
 | 
			
		||||
    VT_PROJ_CLIP = 8
 | 
			
		||||
    VT_PROJ_CLIP = 8,
 | 
			
		||||
    VT_TIME_MAJOR = 10
 | 
			
		||||
  };
 | 
			
		||||
  ActivationFunctionType fused_activation_function() const {
 | 
			
		||||
    return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
 | 
			
		||||
@ -3557,11 +3560,15 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb
 | 
			
		||||
  float proj_clip() const {
 | 
			
		||||
    return GetField<float>(VT_PROJ_CLIP, 0.0f);
 | 
			
		||||
  }
 | 
			
		||||
  bool time_major() const {
 | 
			
		||||
    return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
 | 
			
		||||
  }
 | 
			
		||||
  // Checks this table with the given verifier: table start first, then each
  // field in declaration order, then the table end. Stops at the first failing
  // check and returns false; returns the result of EndTable() otherwise.
  bool Verify(flatbuffers::Verifier &verifier) const {
    if (!VerifyTableStart(verifier)) {
      return false;
    }
    if (!VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION)) {
      return false;
    }
    if (!VerifyField<float>(verifier, VT_CELL_CLIP)) {
      return false;
    }
    if (!VerifyField<float>(verifier, VT_PROJ_CLIP)) {
      return false;
    }
    if (!VerifyField<uint8_t>(verifier, VT_TIME_MAJOR)) {
      return false;
    }
    return verifier.EndTable();
  }
 | 
			
		||||
  UnidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
 | 
			
		||||
@ -3581,6 +3588,9 @@ struct UnidirectionalSequenceLSTMOptionsBuilder {
 | 
			
		||||
  // Adds the proj_clip field to the table being built, passing 0.0f as the
  // field's default value to AddElement.
  void add_proj_clip(float proj_clip) {
    const float default_proj_clip = 0.0f;
    fbb_.AddElement<float>(UnidirectionalSequenceLSTMOptions::VT_PROJ_CLIP,
                           proj_clip, default_proj_clip);
  }
 | 
			
		||||
  // Adds the time_major field to the table being built. The bool is converted
  // to uint8_t for storage, and 0 is passed as the field's default value.
  void add_time_major(bool time_major) {
    const uint8_t encoded_flag = static_cast<uint8_t>(time_major);
    fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_TIME_MAJOR,
                             encoded_flag, 0);
  }
 | 
			
		||||
  explicit UnidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
 | 
			
		||||
        : fbb_(_fbb) {
 | 
			
		||||
    start_ = fbb_.StartTable();
 | 
			
		||||
@ -3597,10 +3607,12 @@ inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirection
 | 
			
		||||
    flatbuffers::FlatBufferBuilder &_fbb,
 | 
			
		||||
    ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
 | 
			
		||||
    float cell_clip = 0.0f,
 | 
			
		||||
    float proj_clip = 0.0f) {
 | 
			
		||||
    float proj_clip = 0.0f,
 | 
			
		||||
    bool time_major = false) {
 | 
			
		||||
  UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
 | 
			
		||||
  builder_.add_proj_clip(proj_clip);
 | 
			
		||||
  builder_.add_cell_clip(cell_clip);
 | 
			
		||||
  builder_.add_time_major(time_major);
 | 
			
		||||
  builder_.add_fused_activation_function(fused_activation_function);
 | 
			
		||||
  return builder_.Finish();
 | 
			
		||||
}
 | 
			
		||||
@ -8060,6 +8072,7 @@ inline void UnidirectionalSequenceLSTMOptions::UnPackTo(UnidirectionalSequenceLS
 | 
			
		||||
  { auto _e = fused_activation_function(); _o->fused_activation_function = _e; };
 | 
			
		||||
  { auto _e = cell_clip(); _o->cell_clip = _e; };
 | 
			
		||||
  { auto _e = proj_clip(); _o->proj_clip = _e; };
 | 
			
		||||
  { auto _e = time_major(); _o->time_major = _e; };
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> UnidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
 | 
			
		||||
@ -8073,11 +8086,13 @@ inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirection
 | 
			
		||||
  auto _fused_activation_function = _o->fused_activation_function;
 | 
			
		||||
  auto _cell_clip = _o->cell_clip;
 | 
			
		||||
  auto _proj_clip = _o->proj_clip;
 | 
			
		||||
  auto _time_major = _o->time_major;
 | 
			
		||||
  return tflite::CreateUnidirectionalSequenceLSTMOptions(
 | 
			
		||||
      _fbb,
 | 
			
		||||
      _fused_activation_function,
 | 
			
		||||
      _cell_clip,
 | 
			
		||||
      _proj_clip);
 | 
			
		||||
      _proj_clip,
 | 
			
		||||
      _time_major);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user