Add support for batch-major input in the unidirectional LSTM Op.
PiperOrigin-RevId: 217406579
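
The op previously assumed time-major input, i.e. a {max_time, n_batch, n_input}
tensor; with the new time_major flag set to false, the same values arrive as
{n_batch, max_time, n_input}. A minimal sketch of the flat-offset arithmetic for
the two layouts (illustration only, not code from this change):

    // Offset of the first feature of (batch b, timestep t) in a flat buffer.
    int InputOffset(bool time_major, int b, int t, int max_time, int n_batch,
                    int n_input) {
      return time_major ? (t * n_batch + b) * n_input   // {max_time, n_batch, n_input}
                        : (b * max_time + t) * n_input; // {n_batch, max_time, n_input}
    }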
commit a3f855aca2 (parent e6440a80c8)
@@ -187,10 +187,13 @@ typedef struct {
 } TfLiteLSTMParams;
 
 typedef struct {
-  // Parameters for the LSTM kernel.
+  // Parameters needed for the underlying LSTM.
   TfLiteFusedActivation activation;
   float cell_clip;
   float proj_clip;
+
+  // If set to true then the first dimension is time, otherwise batch.
+  bool time_major;
 } TfLiteUnidirectionalSequenceLSTMParams;
 
 typedef struct {

@@ -399,11 +399,12 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
             parse_activation(seq_lstm_params->fused_activation_function());
         params->cell_clip = seq_lstm_params->cell_clip();
         params->proj_clip = seq_lstm_params->proj_clip();
+        params->time_major = seq_lstm_params->time_major();
       }
       *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
 
     case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: {
       auto params =
           allocator->AllocatePOD<TfLiteBidirectionalSequenceLSTMParams>();

@@ -876,6 +876,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0;
   const auto actual_bw_output = params->merge_outputs ? fw_output : bw_output;
 
+  // TODO(mirkov): add batch_major support (http://b/117326122).
   switch (fw_input_to_output_weights->type) {
     case kTfLiteFloat32: {
       TfLiteStatus fw_pass_status = lstm_eval::EvalFloat(
@@ -889,8 +890,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           fw_aux_input_to_output_weights, fw_input_gate_bias,
           fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
           fw_projection_weights, fw_projection_bias, &lstm_params,
-          /*forward_sequence=*/true, /*output_offset=*/0, fw_scratch_buffer,
-          fw_activation_state, fw_cell_state, fw_output);
+          /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
+          fw_scratch_buffer, fw_activation_state, fw_cell_state, fw_output);
       TF_LITE_ENSURE_OK(context, fw_pass_status);
 
       TfLiteStatus bw_pass_status = lstm_eval::EvalFloat(
@@ -904,8 +905,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           bw_aux_input_to_output_weights, bw_input_gate_bias,
           bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
           bw_projection_weights, bw_projection_bias, &lstm_params,
-          /*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer,
-          bw_activation_state, bw_cell_state, actual_bw_output);
+          /*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,
+          bw_scratch_buffer, bw_activation_state, bw_cell_state,
+          actual_bw_output);
       TF_LITE_ENSURE_OK(context, bw_pass_status);
       return kTfLiteOk;
     }
@@ -942,11 +944,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           fw_aux_input_to_output_weights, fw_input_gate_bias,
           fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
           fw_projection_weights, fw_projection_bias, &lstm_params,
-          /*forward_sequence=*/true, /*output_offset=*/0, fw_scratch_buffer,
-          scaling_factors, prod_scaling_factors, recovered_cell_weights,
-          input_quantized, aux_input_quantized, fw_activation_state_quantized,
-          fw_cell_state_quantized, fw_activation_state, fw_cell_state,
-          fw_output);
+          /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
+          fw_scratch_buffer, scaling_factors, prod_scaling_factors,
+          recovered_cell_weights, input_quantized, aux_input_quantized,
+          fw_activation_state_quantized, fw_cell_state_quantized,
+          fw_activation_state, fw_cell_state, fw_output);
       TF_LITE_ENSURE_OK(context, fw_pass_status);
 
       TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
@@ -960,11 +962,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           bw_aux_input_to_output_weights, bw_input_gate_bias,
           bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
           bw_projection_weights, bw_projection_bias, &lstm_params,
-          /*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer,
-          scaling_factors, prod_scaling_factors, recovered_cell_weights,
-          input_quantized, aux_input_quantized, bw_activation_state_quantized,
-          bw_cell_state_quantized, bw_activation_state, bw_cell_state,
-          actual_bw_output);
+          /*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,
+          bw_scratch_buffer, scaling_factors, prod_scaling_factors,
+          recovered_cell_weights, input_quantized, aux_input_quantized,
+          bw_activation_state_quantized, bw_cell_state_quantized,
+          bw_activation_state, bw_cell_state, actual_bw_output);
       TF_LITE_ENSURE_OK(context, bw_pass_status);
       return kTfLiteOk;
     }

@@ -497,6 +497,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
           forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
           projection_bias, params, /*forward_sequence=*/true,
+          /*time_major=*/true,
           /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
           output);
     }
@@ -524,8 +525,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
           forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
           projection_bias, params, /*forward_sequence=*/true,
-          /*output_offset=*/0, scratch_buffer, scaling_factors,
-          prod_scaling_factors, recovered_cell_weights, input_quantized,
+          /*time_major=*/true, /*output_offset=*/0, scratch_buffer,
+          scaling_factors, prod_scaling_factors, recovered_cell_weights,
+          input_quantized,
           /*aux_input_quantized=*/nullptr, activation_state_quantized,
           cell_state_quantized, activation_state, cell_state, output);
     }

@@ -710,9 +710,10 @@ TfLiteStatus EvalFloat(
     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
     const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
-    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
-    TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
-    TfLiteTensor* cell_state, TfLiteTensor* output) {
+    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
+    int output_offset, TfLiteTensor* scratch_buffer,
+    TfLiteTensor* activation_state, TfLiteTensor* cell_state,
+    TfLiteTensor* output) {
   TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
   const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
   const int n_batch = input->dims->data[input->dims->size - 2];
@@ -777,36 +778,71 @@ TfLiteStatus EvalFloat(
     aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f;
   }
 
-  // Loop through the sequence.
   const int output_batch_leading_dim =
       output->dims->data[output->dims->size - 1];
-  const int input_step = n_batch * n_input;
-  const int output_step = n_batch * output_batch_leading_dim;
-  for (int t = 0; t < max_time; t++) {
-    // If this is the forward_sequence, step forward, otherwise step backwards.
-    const int t_rel = forward_sequence ? t : max_time - t - 1;
-    const float* input_ptr = input->data.f + t_rel * input_step;
-    if (aux_input) {
-      aux_input_ptr = aux_input->data.f + t_rel * input_step;
-    }
-    float* output_ptr_time =
-        output->data.f + t_rel * output_step + output_offset;
+  if (time_major) {
+    // Loop through the sequence.
+    const int input_step = n_batch * n_input;
+    const int output_step = n_batch * output_batch_leading_dim;
+    for (int t = 0; t < max_time; t++) {
+      // If this is the forward_sequence, step forward, otherwise step
+      // backwards.
+      const int t_rel = forward_sequence ? t : max_time - t - 1;
+      const float* input_ptr = input->data.f + t_rel * input_step;
+      if (aux_input) {
+        aux_input_ptr = aux_input->data.f + t_rel * input_step;
+      }
+      float* output_ptr_time =
+          output->data.f + t_rel * output_step + output_offset;
 
-    LstmStepWithAuxInput(
-        input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f,
-        input_to_cell_weights->data.f, input_to_output_weights->data.f,
-        aux_input_ptr, aux_input_to_input_weights_ptr,
-        aux_input_to_forget_weights_ptr, aux_input_to_cell_weights_ptr,
-        aux_input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
-        recurrent_to_forget_weights->data.f, recurrent_to_cell_weights->data.f,
-        recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
-        cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
-        input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
-        output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
-        params, n_batch, n_cell, n_input, aux_input_size, n_output,
-        output_batch_leading_dim, activation_state->data.f, cell_state->data.f,
-        input_gate_scratch, forget_gate_scratch, cell_scratch,
-        output_gate_scratch, output_ptr_time);
+      LstmStepWithAuxInput(
+          input_ptr, input_to_input_weights_ptr,
+          input_to_forget_weights->data.f, input_to_cell_weights->data.f,
+          input_to_output_weights->data.f, aux_input_ptr,
+          aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
+          aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
+          recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
+          recurrent_to_cell_weights->data.f,
+          recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+          cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+          input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+          output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
+          params, n_batch, n_cell, n_input, aux_input_size, n_output,
+          output_batch_leading_dim, activation_state->data.f,
+          cell_state->data.f, input_gate_scratch, forget_gate_scratch,
+          cell_scratch, output_gate_scratch, output_ptr_time);
+    }
+  } else {
+    for (int b = 0; b < n_batch; b++) {
+      const int input_step = n_input;
+      const int output_step = output_batch_leading_dim;
+      for (int t = 0; t < max_time; t++) {
+        // If this is the forward_sequence, step forward, otherwise step
+        // backwards.
+        const int t_rel = forward_sequence ? t : max_time - t - 1;
+        const float* input_ptr = input->data.f + t_rel * input_step;
+        float* output_ptr_time =
+            output->data.f + t_rel * output_step + output_offset;
+
+        LstmStepWithAuxInput(
+            input_ptr, input_to_input_weights_ptr,
+            input_to_forget_weights->data.f, input_to_cell_weights->data.f,
+            input_to_output_weights->data.f, aux_input_ptr,
+            aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
+            aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
+            recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
+            recurrent_to_cell_weights->data.f,
+            recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+            cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+            input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+            output_gate_bias->data.f, projection_weights_ptr,
+            projection_bias_ptr, params, /*n_batch=*/1, n_cell, n_input,
+            aux_input_size, n_output, output_batch_leading_dim,
+            activation_state->data.f, cell_state->data.f, input_gate_scratch,
+            forget_gate_scratch, cell_scratch, output_gate_scratch,
+            output_ptr_time);
+      }
+    }
+  }
   return kTfLiteOk;
 }
@@ -830,13 +866,13 @@ TfLiteStatus EvalHybrid(
     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
     const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
-    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
-    TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
-    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
-    TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
-    TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
-    TfLiteTensor* output_state, TfLiteTensor* cell_state,
-    TfLiteTensor* output) {
+    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
+    int output_offset, TfLiteTensor* scratch_buffer,
+    TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
+    TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
+    TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
+    TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
+    TfLiteTensor* cell_state, TfLiteTensor* output) {
   TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
   const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
   const int n_batch = input->dims->data[input->dims->size - 2];
@@ -990,45 +1026,90 @@ TfLiteStatus EvalHybrid(
         aux_input_to_output_weights->params.scale;
   }
 
-  // Feed the sequence into the LSTM step-by-step.
   const int output_batch_leading_dim =
       output->dims->data[output->dims->size - 1];
-  const int input_step = n_batch * n_input;
-  const int output_step = n_batch * output_batch_leading_dim;
-  for (int t = 0; t < max_time; t++) {
-    // If this is the forward_sequence, step forward, otherwise step backwards.
-    const int t_rel = forward_sequence ? t : max_time - t - 1;
-    const float* input_ptr = input->data.f + t_rel * input_step;
-    if (aux_input) {
-      aux_input_ptr = aux_input->data.f + t_rel * input_step;
-    }
-    float* output_ptr = output->data.f + t_rel * output_step + output_offset;
+  if (time_major) {
+    // Feed the sequence into the LSTM step-by-step.
+    const int input_step = n_batch * n_input;
+    const int output_step = n_batch * output_batch_leading_dim;
+    for (int t = 0; t < max_time; t++) {
+      // If this is the forward_sequence, step forward, otherwise step
+      // backwards.
+      const int t_rel = forward_sequence ? t : max_time - t - 1;
+      const float* input_ptr = input->data.f + t_rel * input_step;
+      if (aux_input) {
+        aux_input_ptr = aux_input->data.f + t_rel * input_step;
+      }
+      float* output_ptr = output->data.f + t_rel * output_step + output_offset;
 
-    LstmStepWithAuxInput(
-        input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
-        input_to_forget_weights_ptr, input_to_forget_weights_scale,
-        input_to_cell_weights_ptr, input_to_cell_weights_scale,
-        input_to_output_weights_ptr, input_to_output_weights_scale,
-        aux_input_ptr, aux_input_to_input_weights_ptr,
-        aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
-        aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
-        aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
-        aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
-        recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
-        recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
-        recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
-        recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
-        cell_to_input_weights_scale, cell_to_forget_weights_ptr,
-        cell_to_forget_weights_scale, cell_to_output_weights_ptr,
-        cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
-        cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
-        projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
-        n_input, aux_input_size, n_output, output_batch_leading_dim,
-        input_gate_scratch, forget_gate_scratch, cell_scratch,
-        output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
-        recovered_cell_weights_ptr, quantized_input_ptr,
-        quantized_aux_input_ptr, quantized_output_state_ptr,
-        quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr);
+      LstmStepWithAuxInput(
+          input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+          input_to_forget_weights_ptr, input_to_forget_weights_scale,
+          input_to_cell_weights_ptr, input_to_cell_weights_scale,
+          input_to_output_weights_ptr, input_to_output_weights_scale,
+          aux_input_ptr, aux_input_to_input_weights_ptr,
+          aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
+          aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
+          aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
+          aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
+          recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
+          recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
+          recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
+          recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
+          cell_to_input_weights_scale, cell_to_forget_weights_ptr,
+          cell_to_forget_weights_scale, cell_to_output_weights_ptr,
+          cell_to_output_weights_scale, input_gate_bias_ptr,
+          forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
+          projection_weights_ptr, projection_weights_scale, projection_bias_ptr,
+          params, n_batch, n_cell, n_input, aux_input_size, n_output,
+          output_batch_leading_dim, input_gate_scratch, forget_gate_scratch,
+          cell_scratch, output_gate_scratch, scaling_factors_ptr,
+          prod_scaling_factors_ptr, recovered_cell_weights_ptr,
+          quantized_input_ptr, quantized_aux_input_ptr,
+          quantized_output_state_ptr, quantized_cell_state_ptr,
+          output_state_ptr, cell_state_ptr, output_ptr);
+    }
+  } else {
+    for (int b = 0; b < n_batch; b++) {
+      const int input_step = n_input;
+      const int output_step = output_batch_leading_dim;
+      for (int t = 0; t < max_time; t++) {
+        // If this is the forward_sequence, step forward, otherwise step
+        // backwards.
+        const int t_rel = forward_sequence ? t : max_time - t - 1;
+        const float* input_ptr = input->data.f + t_rel * input_step;
+        float* output_ptr =
+            output->data.f + t_rel * output_step + output_offset;
+
+        LstmStepWithAuxInput(
+            input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+            input_to_forget_weights_ptr, input_to_forget_weights_scale,
+            input_to_cell_weights_ptr, input_to_cell_weights_scale,
+            input_to_output_weights_ptr, input_to_output_weights_scale,
+            aux_input_ptr, aux_input_to_input_weights_ptr,
+            aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
+            aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
+            aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
+            aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
+            recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
+            recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
+            recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
+            recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
+            cell_to_input_weights_scale, cell_to_forget_weights_ptr,
+            cell_to_forget_weights_scale, cell_to_output_weights_ptr,
+            cell_to_output_weights_scale, input_gate_bias_ptr,
+            forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
+            projection_weights_ptr, projection_weights_scale,
+            projection_bias_ptr, params, n_batch, n_cell, n_input,
+            aux_input_size, n_output, output_batch_leading_dim,
+            input_gate_scratch, forget_gate_scratch, cell_scratch,
+            output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
+            recovered_cell_weights_ptr, quantized_input_ptr,
+            quantized_aux_input_ptr, quantized_output_state_ptr,
+            quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
+            output_ptr);
+      }
+    }
+  }
 
   return kTfLiteOk;

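For orientation (a reading aid, not code from this change): both EvalFloat and
EvalHybrid above now branch on time_major. The time-major branch advances in
whole-timestep strides and runs one LSTM step over the entire batch, while the
batch-major branch walks each batch entry with single-row strides. A loop
skeleton of the EvalFloat variant, with Step standing in for
LstmStepWithAuxInput:

    if (time_major) {
      const int input_step = n_batch * n_input;  // one whole timestep per step
      for (int t = 0; t < max_time; ++t) {
        Step(input + t * input_step, /*step_batch=*/n_batch);
      }
    } else {
      const int input_step = n_input;  // one (batch, timestep) row per step
      for (int b = 0; b < n_batch; ++b) {
        for (int t = 0; t < max_time; ++t) {
          Step(input + t * input_step, /*step_batch=*/1);
        }
      }
    }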
@@ -42,9 +42,10 @@ TfLiteStatus EvalFloat(
     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
     const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
-    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
-    TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
-    TfLiteTensor* cell_state, TfLiteTensor* output);
+    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
+    int output_offset, TfLiteTensor* scratch_buffer,
+    TfLiteTensor* activation_state, TfLiteTensor* cell_state,
+    TfLiteTensor* output);
 
 TfLiteStatus EvalHybrid(
     const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,

@@ -65,12 +66,13 @@ TfLiteStatus EvalHybrid(
     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
     const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
-    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
-    TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
-    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
-    TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
-    TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
-    TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output);
+    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
+    int output_offset, TfLiteTensor* scratch_buffer,
+    TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
+    TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
+    TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
+    TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
+    TfLiteTensor* cell_state, TfLiteTensor* output);
 
 }  // namespace lstm_eval
 }  // namespace builtin

@@ -260,8 +260,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
   TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
   TF_LITE_ENSURE(context, input->dims->size > 1);
-  const int max_time = input->dims->data[0];
-  const int n_batch = input->dims->data[1];
+  const auto* params =
+      reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
+          node->builtin_data);
+  const bool time_major = params->time_major;
+  const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
   const int n_input = input->dims->data[2];
 
   const TfLiteTensor* input_to_output_weights =
@@ -296,10 +299,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
 
   // Resize the output tensors.
-  TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
-  output_size->data[0] = max_time;
-  output_size->data[1] = n_batch;
-  output_size->data[2] = n_output;
+  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
+  output_size->data[input->dims->size - 1] = n_output;
   TF_LITE_ENSURE_OK(context,
                     context->ResizeTensor(context, output, output_size));
 
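Concretely, copying the input dims and overwriting only the innermost entry
keeps whichever layout the model declared (worked example, not code from this
change):

    // time-major:  input {max_time, n_batch, n_input} -> output {max_time, n_batch, n_output}
    // batch-major: input {n_batch, max_time, n_input} -> output {n_batch, max_time, n_output}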
@@ -436,6 +437,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const auto* params =
       reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
          node->builtin_data);
+  const bool time_major = params->time_major;
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
 
   const TfLiteTensor* input_to_input_weights =
@@ -506,7 +508,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           /*aux_input_to_cell_weights=*/nullptr,
           /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
           forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
-          projection_bias, &lstm_params, /*forward_sequence=*/true,
+          projection_bias, &lstm_params, /*forward_sequence=*/true, time_major,
           /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
           output);
     }
@@ -533,7 +535,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           /*aux_input_to_cell_weights=*/nullptr,
           /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
           forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
-          projection_bias, &lstm_params, /*forward_sequence=*/true,
+          projection_bias, &lstm_params, /*forward_sequence=*/true, time_major,
           /*output_offset=*/0, scratch_buffer, scaling_factors,
           prod_scaling_factors, recovered_cell_weights, input_quantized,
           /*aux_input_quantized=*/nullptr, activation_state_quantized,

@@ -32,7 +32,7 @@ using ::testing::ElementsAreArray;
 class UnidirectionalLSTMOpModel : public SingleOpModel {
  public:
   UnidirectionalLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
-                            int sequence_length, bool use_cifg,
+                            int sequence_length, bool time_major, bool use_cifg,
                             bool use_peephole, bool use_projection_weights,
                             bool use_projection_bias, float cell_clip,
                             float proj_clip,
@@ -110,12 +110,12 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
 
     output_ = AddOutput(TensorType_FLOAT32);
 
-    SetBuiltinOp(
-        BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
-        BuiltinOptions_UnidirectionalSequenceLSTMOptions,
-        CreateUnidirectionalSequenceLSTMOptions(
-            builder_, ActivationFunctionType_TANH, cell_clip, proj_clip)
-            .Union());
+    SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
+                 BuiltinOptions_UnidirectionalSequenceLSTMOptions,
+                 CreateUnidirectionalSequenceLSTMOptions(
+                     builder_, ActivationFunctionType_TANH, cell_clip,
+                     proj_clip, time_major)
+                     .Union());
     BuildInterpreter(input_shapes);
   }
 
@@ -241,12 +241,12 @@ class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
  public:
   HybridUnidirectionalLSTMOpModel(
       int n_batch, int n_input, int n_cell, int n_output, int sequence_length,
-      bool use_cifg, bool use_peephole, bool use_projection_weights,
-      bool use_projection_bias, float cell_clip, float proj_clip,
-      const std::vector<std::vector<int>>& input_shapes)
+      bool time_major, bool use_cifg, bool use_peephole,
+      bool use_projection_weights, bool use_projection_bias, float cell_clip,
+      float proj_clip, const std::vector<std::vector<int>>& input_shapes)
       : UnidirectionalLSTMOpModel(
-            n_batch, n_input, n_cell, n_output, sequence_length, use_cifg,
-            use_peephole, use_projection_weights, use_projection_bias,
+            n_batch, n_input, n_cell, n_output, sequence_length, time_major,
+            use_cifg, use_peephole, use_projection_weights, use_projection_bias,
             cell_clip, proj_clip, input_shapes, TensorType_UINT8) {}
 
   void SetInputToInputWeights(const std::vector<float>& f) {
@@ -326,21 +326,32 @@ class BaseLstmTest : public ::testing::Test {
   // Compares output up to tolerance to the result of the lstm given the input.
   void VerifyGoldens(const std::vector<std::vector<float>>& input,
                      const std::vector<std::vector<float>>& output,
-                     UnidirectionalLSTMOpModel* lstm, float tolerance = 1e-5) {
+                     UnidirectionalLSTMOpModel* lstm, float tolerance = 1e-5,
+                     bool time_major = true) {
     const int num_batches = input.size();
     EXPECT_GT(num_batches, 0);
     const int num_inputs = lstm->num_inputs();
     EXPECT_GT(num_inputs, 0);
     const int input_sequence_size = input[0].size() / num_inputs;
     EXPECT_GT(input_sequence_size, 0);
-    // Feed the whole sequence as input.
-    for (int i = 0; i < input_sequence_size; ++i) {
-      for (int b = 0; b < num_batches; ++b) {
-        const float* batch_start = input[b].data() + i * num_inputs;
-        const float* batch_end = batch_start + num_inputs;
-
-        lstm->SetInput(((i * num_batches) + b) * lstm->num_inputs(),
-                       batch_start, batch_end);
+    if (time_major) {
+      // Feed the whole sequence as input.
+      for (int i = 0; i < input_sequence_size; ++i) {
+        for (int b = 0; b < num_batches; ++b) {
+          const float* batch_start = input[b].data() + i * num_inputs;
+          const float* batch_end = batch_start + num_inputs;
+
+          lstm->SetInput(((i * num_batches) + b) * num_inputs, batch_start,
+                         batch_end);
+        }
+      }
+    } else {
+      for (int b = 0; b < num_batches; ++b) {
+        const float* batch_start = input[b].data();
+        const float* batch_end = batch_start + input_sequence_size * num_inputs;
+
+        lstm->SetInput(b * input_sequence_size * num_inputs, batch_start,
+                       batch_end);
       }
     }
 
@@ -349,15 +360,25 @@ class BaseLstmTest : public ::testing::Test {
     const int num_outputs = lstm->num_outputs();
     EXPECT_GT(num_outputs, 0);
     std::vector<float> expected;
-    for (int i = 0; i < input_sequence_size; ++i) {
-      for (int b = 0; b < num_batches; ++b) {
-        const float* golden_start_batch = output[b].data() + i * num_outputs;
-        const float* golden_end_batch = golden_start_batch + num_outputs;
-
-        expected.insert(expected.end(), golden_start_batch, golden_end_batch);
+    if (time_major) {
+      for (int i = 0; i < input_sequence_size; ++i) {
+        for (int b = 0; b < num_batches; ++b) {
+          const float* golden_start_batch = output[b].data() + i * num_outputs;
+          const float* golden_end_batch = golden_start_batch + num_outputs;
+
+          expected.insert(expected.end(), golden_start_batch, golden_end_batch);
+        }
+      }
+    } else {
+      for (int b = 0; b < num_batches; ++b) {
+        const float* golden_batch_start = output[b].data();
+        const float* golden_batch_end =
+            golden_batch_start + input_sequence_size * num_outputs;
+
+        expected.insert(expected.end(), golden_batch_start, golden_batch_end);
       }
     }
 
     EXPECT_THAT(lstm->GetOutput(),
                 ElementsAreArray(ArrayFloatNear(expected, tolerance)));
   }
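As a worked example of the two feeding orders (values invented for
illustration): with num_batches = 2, input_sequence_size = 2 and num_inputs = 1,
time-major order is x(t0,b0), x(t0,b1), x(t1,b0), x(t1,b1), while batch-major
order is x(b0,t0), x(b0,t1), x(b1,t0), x(b1,t1). The golden outputs are
collected in the same order in which the inputs were fed.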
@@ -422,7 +443,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
 
   UnidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length,
-      /*use_cifg=*/false, /*use_peephole=*/false,
+      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/false,
       /*use_projection_weights=*/false,
       /*use_projection_bias=*/false,
       /*cell_clip=*/0.0, /*proj_clip=*/0.0,
@@ -473,6 +494,73 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
 }
 
+TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
+       LstmBlackBoxTestBatchMajor) {
+  const int n_batch = 1;
+  const int n_input = 2;
+  // n_cell and n_output have the same size when there is no projection.
+  const int n_cell = 4;
+  const int n_output = 4;
+  const int sequence_length = 3;
+
+  UnidirectionalLSTMOpModel lstm(
+      n_batch, n_input, n_cell, n_output, sequence_length,
+      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/false,
+      /*use_projection_weights=*/false,
+      /*use_projection_bias=*/false,
+      /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+      {
+          {sequence_length, n_batch, n_input},  // input tensor
+
+          {n_cell, n_input},  // input_to_input_weight tensor
+          {n_cell, n_input},  // input_to_forget_weight tensor
+          {n_cell, n_input},  // input_to_cell_weight tensor
+          {n_cell, n_input},  // input_to_output_weight tensor
+
+          {n_cell, n_output},  // recurrent_to_input_weight tensor
+          {n_cell, n_output},  // recurrent_to_forget_weight tensor
+          {n_cell, n_output},  // recurrent_to_cell_weight tensor
+          {n_cell, n_output},  // recurrent_to_output_weight tensor
+
+          {0},  // cell_to_input_weight tensor
+          {0},  // cell_to_forget_weight tensor
+          {0},  // cell_to_output_weight tensor
+
+          {n_cell},  // input_gate_bias tensor
+          {n_cell},  // forget_gate_bias tensor
+          {n_cell},  // cell_bias tensor
+          {n_cell},  // output_gate_bias tensor
+
+          {0, 0},  // projection_weight tensor
+          {0},     // projection_bias tensor
+
+          {n_batch, n_output},  // activation_state tensor
+          {n_batch, n_cell},    // cell_state tensor
+      });
+
+  lstm.SetInputToInputWeights(input_to_input_weights_);
+  lstm.SetInputToCellWeights(input_to_cell_weights_);
+  lstm.SetInputToForgetWeights(input_to_forget_weights_);
+  lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+  lstm.SetInputGateBias(input_gate_bias_);
+  lstm.SetCellBias(cell_gate_bias_);
+  lstm.SetForgetGateBias(forget_gate_bias_);
+  lstm.SetOutputGateBias(output_gate_bias_);
+
+  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+  // Reshuffle input and output to batch major format.
+  std::vector<std::vector<float>> input;
+  std::vector<std::vector<float>> output;
+
+  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/1e-5,
+                /*time_major=*/false);
+}
+
 TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
   const int n_batch = 1;
   const int n_input = 2;
@@ -483,7 +571,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
 
   HybridUnidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length,
-      /*use_cifg=*/false, /*use_peephole=*/false,
+      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/false,
       /*use_projection_weights=*/false,
       /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
       {
@@ -591,7 +679,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
 
   UnidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length,
-      /*use_cifg=*/true, /*use_peephole=*/true,
+      /*time_major=*/true, /*use_cifg=*/true, /*use_peephole=*/true,
       /*use_projection_weights=*/false,
       /*use_projection_bias=*/false,
       /*cell_clip=*/0.0, /*proj_clip=*/0.0,
@@ -652,7 +740,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
 
   HybridUnidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length,
-      /*use_cifg=*/true, /*use_peephole=*/true,
+      /*time_major=*/true, /*use_cifg=*/true, /*use_peephole=*/true,
       /*use_projection_weights=*/false,
       /*use_projection_bias=*/false,
       /*cell_clip=*/0.0, /*proj_clip=*/0.0,
@@ -1311,7 +1399,7 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
 
   UnidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length,
-      /*use_cifg=*/false, /*use_peephole=*/true,
+      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true,
       /*use_projection_weights=*/true,
       /*use_projection_bias=*/false,
       /*cell_clip=*/0.0, /*proj_clip=*/0.0,
@@ -1377,7 +1465,7 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
 
   HybridUnidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length,
-      /*use_cifg=*/false, /*use_peephole=*/true,
+      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true,
       /*use_projection_weights=*/true,
       /*use_projection_bias=*/false,
       /*cell_clip=*/0.0, /*proj_clip=*/0.0,

@@ -407,6 +407,9 @@ table UnidirectionalSequenceLSTMOptions {
   fused_activation_function:ActivationFunctionType;
   cell_clip: float; // Optional, 0.0 means no clipping
   proj_clip: float; // Optional, 0.0 means no clipping
+
+  // If true then first dimension is sequence, otherwise batch.
+  time_major:bool;
 }
 
 table BidirectionalSequenceLSTMOptions {

@@ -3534,10 +3534,12 @@ struct UnidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
   ActivationFunctionType fused_activation_function;
   float cell_clip;
   float proj_clip;
+  bool time_major;
   UnidirectionalSequenceLSTMOptionsT()
       : fused_activation_function(ActivationFunctionType_NONE),
         cell_clip(0.0f),
-        proj_clip(0.0f) {
+        proj_clip(0.0f),
+        time_major(false) {
   }
 };
 
@@ -3546,7 +3548,8 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   enum {
     VT_FUSED_ACTIVATION_FUNCTION = 4,
     VT_CELL_CLIP = 6,
-    VT_PROJ_CLIP = 8
+    VT_PROJ_CLIP = 8,
+    VT_TIME_MAJOR = 10
   };
   ActivationFunctionType fused_activation_function() const {
     return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
@@ -3557,11 +3560,15 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   float proj_clip() const {
     return GetField<float>(VT_PROJ_CLIP, 0.0f);
   }
+  bool time_major() const {
+    return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
+  }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
            VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
            VerifyField<float>(verifier, VT_CELL_CLIP) &&
            VerifyField<float>(verifier, VT_PROJ_CLIP) &&
+           VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
            verifier.EndTable();
   }
   UnidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -3581,6 +3588,9 @@ struct UnidirectionalSequenceLSTMOptionsBuilder {
   void add_proj_clip(float proj_clip) {
     fbb_.AddElement<float>(UnidirectionalSequenceLSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f);
   }
+  void add_time_major(bool time_major) {
+    fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast<uint8_t>(time_major), 0);
+  }
   explicit UnidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
         : fbb_(_fbb) {
     start_ = fbb_.StartTable();
@@ -3597,10 +3607,12 @@ inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirectionalSequenceLSTMOptions(
     flatbuffers::FlatBufferBuilder &_fbb,
     ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
     float cell_clip = 0.0f,
-    float proj_clip = 0.0f) {
+    float proj_clip = 0.0f,
+    bool time_major = false) {
   UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
   builder_.add_proj_clip(proj_clip);
   builder_.add_cell_clip(cell_clip);
+  builder_.add_time_major(time_major);
   builder_.add_fused_activation_function(fused_activation_function);
   return builder_.Finish();
 }
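A usage sketch of the regenerated builder API (assumes an existing
flatbuffers::FlatBufferBuilder; mirrors the SetBuiltinOp call in the test
above):

    flatbuffers::FlatBufferBuilder fbb;
    auto options = CreateUnidirectionalSequenceLSTMOptions(
        fbb, ActivationFunctionType_TANH,
        /*cell_clip=*/0.0f, /*proj_clip=*/0.0f,
        /*time_major=*/false);  // request batch-major input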
@@ -8060,6 +8072,7 @@ inline void UnidirectionalSequenceLSTMOptions::UnPackTo(UnidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
   { auto _e = fused_activation_function(); _o->fused_activation_function = _e; };
   { auto _e = cell_clip(); _o->cell_clip = _e; };
   { auto _e = proj_clip(); _o->proj_clip = _e; };
+  { auto _e = time_major(); _o->time_major = _e; };
 }
 
 inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> UnidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -8073,11 +8086,13 @@ inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirectionalSequenceLSTMOptions(
   auto _fused_activation_function = _o->fused_activation_function;
   auto _cell_clip = _o->cell_clip;
   auto _proj_clip = _o->proj_clip;
+  auto _time_major = _o->time_major;
   return tflite::CreateUnidirectionalSequenceLSTMOptions(
       _fbb,
       _fused_activation_function,
       _cell_clip,
-      _proj_clip);
+      _proj_clip,
+      _time_major);
 }
 
 inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {