Add support for batch-major input in the unidirectional LSTM Op.

PiperOrigin-RevId: 217406579
A. Unique TensorFlower 2018-10-16 16:14:24 -07:00 committed by TensorFlower Gardener
parent e6440a80c8
commit a3f855aca2
10 changed files with 342 additions and 144 deletions
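For orientation: with time_major set, the sequence input is laid out as [max_time, n_batch, n_input]; with it unset the input is batch major, [n_batch, max_time, n_input]. The following standalone sketch shows the index arithmetic the kernels below rely on (illustrative helper names, not code from this commit):

// Flat offset of element (batch b, time t, feature i) in each layout.
inline int TimeMajorOffset(int t, int b, int i, int n_batch, int n_input) {
  return (t * n_batch + b) * n_input + i;  // [max_time, n_batch, n_input]
}
inline int BatchMajorOffset(int b, int t, int i, int max_time, int n_input) {
  return (b * max_time + t) * n_input + i;  // [n_batch, max_time, n_input]
}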

builtin_op_data.h

@@ -187,10 +187,13 @@ typedef struct {
 } TfLiteLSTMParams;
 
 typedef struct {
-  // Parameters for the LSTM kernel.
+  // Parameters needed for the underlying LSTM.
   TfLiteFusedActivation activation;
   float cell_clip;
   float proj_clip;
+
+  // If set to true then the first dimension is time, otherwise batch.
+  bool time_major;
 } TfLiteUnidirectionalSequenceLSTMParams;
 
 typedef struct {

flatbuffer_conversions.cc

@@ -399,11 +399,11 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
            parse_activation(seq_lstm_params->fused_activation_function());
        params->cell_clip = seq_lstm_params->cell_clip();
        params->proj_clip = seq_lstm_params->proj_clip();
+        params->time_major = seq_lstm_params->time_major();
      }
      *builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: {
      auto params =
          allocator->AllocatePOD<TfLiteBidirectionalSequenceLSTMParams>();

bidirectional_sequence_lstm.cc

@@ -876,6 +876,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0;
   const auto actual_bw_output = params->merge_outputs ? fw_output : bw_output;
 
+  // TODO(mirkov): add batch_major support (http://b/117326122).
   switch (fw_input_to_output_weights->type) {
     case kTfLiteFloat32: {
       TfLiteStatus fw_pass_status = lstm_eval::EvalFloat(
@@ -889,8 +890,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           fw_aux_input_to_output_weights, fw_input_gate_bias,
           fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
           fw_projection_weights, fw_projection_bias, &lstm_params,
-          /*forward_sequence=*/true, /*output_offset=*/0, fw_scratch_buffer,
-          fw_activation_state, fw_cell_state, fw_output);
+          /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
+          fw_scratch_buffer, fw_activation_state, fw_cell_state, fw_output);
       TF_LITE_ENSURE_OK(context, fw_pass_status);
 
       TfLiteStatus bw_pass_status = lstm_eval::EvalFloat(
@@ -904,8 +905,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           bw_aux_input_to_output_weights, bw_input_gate_bias,
           bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
           bw_projection_weights, bw_projection_bias, &lstm_params,
-          /*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer,
-          bw_activation_state, bw_cell_state, actual_bw_output);
+          /*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,
+          bw_scratch_buffer, bw_activation_state, bw_cell_state,
+          actual_bw_output);
       TF_LITE_ENSURE_OK(context, bw_pass_status);
       return kTfLiteOk;
     }
@@ -942,11 +944,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           fw_aux_input_to_output_weights, fw_input_gate_bias,
           fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
           fw_projection_weights, fw_projection_bias, &lstm_params,
-          /*forward_sequence=*/true, /*output_offset=*/0, fw_scratch_buffer,
-          scaling_factors, prod_scaling_factors, recovered_cell_weights,
-          input_quantized, aux_input_quantized, fw_activation_state_quantized,
-          fw_cell_state_quantized, fw_activation_state, fw_cell_state,
-          fw_output);
+          /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
+          fw_scratch_buffer, scaling_factors, prod_scaling_factors,
+          recovered_cell_weights, input_quantized, aux_input_quantized,
+          fw_activation_state_quantized, fw_cell_state_quantized,
+          fw_activation_state, fw_cell_state, fw_output);
       TF_LITE_ENSURE_OK(context, fw_pass_status);
 
       TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
@@ -960,11 +962,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           bw_aux_input_to_output_weights, bw_input_gate_bias,
           bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
           bw_projection_weights, bw_projection_bias, &lstm_params,
-          /*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer,
-          scaling_factors, prod_scaling_factors, recovered_cell_weights,
-          input_quantized, aux_input_quantized, bw_activation_state_quantized,
-          bw_cell_state_quantized, bw_activation_state, bw_cell_state,
-          actual_bw_output);
+          /*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,
+          bw_scratch_buffer, scaling_factors, prod_scaling_factors,
+          recovered_cell_weights, input_quantized, aux_input_quantized,
+          bw_activation_state_quantized, bw_cell_state_quantized,
+          bw_activation_state, bw_cell_state, actual_bw_output);
       TF_LITE_ENSURE_OK(context, bw_pass_status);
       return kTfLiteOk;
     }

lstm.cc

@@ -497,6 +497,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
           forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
           projection_bias, params, /*forward_sequence=*/true,
+          /*time_major=*/true,
           /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
           output);
     }
@@ -524,8 +525,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
           forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
           projection_bias, params, /*forward_sequence=*/true,
-          /*output_offset=*/0, scratch_buffer, scaling_factors,
-          prod_scaling_factors, recovered_cell_weights, input_quantized,
+          /*time_major=*/true, /*output_offset=*/0, scratch_buffer,
+          scaling_factors, prod_scaling_factors, recovered_cell_weights,
+          input_quantized,
           /*aux_input_quantized=*/nullptr, activation_state_quantized,
           cell_state_quantized, activation_state, cell_state, output);
     }

lstm_eval.cc

@@ -710,9 +710,10 @@ TfLiteStatus EvalFloat(
     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
     const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
-    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
-    TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
-    TfLiteTensor* cell_state, TfLiteTensor* output) {
+    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
+    int output_offset, TfLiteTensor* scratch_buffer,
+    TfLiteTensor* activation_state, TfLiteTensor* cell_state,
+    TfLiteTensor* output) {
   TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
   const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
   const int n_batch = input->dims->data[input->dims->size - 2];
@@ -777,36 +778,71 @@ TfLiteStatus EvalFloat(
     aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f;
   }
 
-  // Loop through the sequence.
   const int output_batch_leading_dim =
       output->dims->data[output->dims->size - 1];
-  const int input_step = n_batch * n_input;
-  const int output_step = n_batch * output_batch_leading_dim;
-  for (int t = 0; t < max_time; t++) {
-    // If this is the forward_sequence, step forward, otherwise step backwards.
-    const int t_rel = forward_sequence ? t : max_time - t - 1;
-    const float* input_ptr = input->data.f + t_rel * input_step;
-    if (aux_input) {
-      aux_input_ptr = aux_input->data.f + t_rel * input_step;
-    }
-    float* output_ptr_time =
-        output->data.f + t_rel * output_step + output_offset;
-
-    LstmStepWithAuxInput(
-        input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f,
-        input_to_cell_weights->data.f, input_to_output_weights->data.f,
-        aux_input_ptr, aux_input_to_input_weights_ptr,
-        aux_input_to_forget_weights_ptr, aux_input_to_cell_weights_ptr,
-        aux_input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
-        recurrent_to_forget_weights->data.f, recurrent_to_cell_weights->data.f,
-        recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
-        cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
-        input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
-        output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
-        params, n_batch, n_cell, n_input, aux_input_size, n_output,
-        output_batch_leading_dim, activation_state->data.f, cell_state->data.f,
-        input_gate_scratch, forget_gate_scratch, cell_scratch,
-        output_gate_scratch, output_ptr_time);
+  if (time_major) {
+    // Loop through the sequence.
+    const int input_step = n_batch * n_input;
+    const int output_step = n_batch * output_batch_leading_dim;
+    for (int t = 0; t < max_time; t++) {
+      // If this is the forward_sequence, step forward, otherwise step
+      // backwards.
+      const int t_rel = forward_sequence ? t : max_time - t - 1;
+      const float* input_ptr = input->data.f + t_rel * input_step;
+      if (aux_input) {
+        aux_input_ptr = aux_input->data.f + t_rel * input_step;
+      }
+      float* output_ptr_time =
+          output->data.f + t_rel * output_step + output_offset;
+
+      LstmStepWithAuxInput(
+          input_ptr, input_to_input_weights_ptr,
+          input_to_forget_weights->data.f, input_to_cell_weights->data.f,
+          input_to_output_weights->data.f, aux_input_ptr,
+          aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
+          aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
+          recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
+          recurrent_to_cell_weights->data.f,
+          recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+          cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+          input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+          output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
+          params, n_batch, n_cell, n_input, aux_input_size, n_output,
+          output_batch_leading_dim, activation_state->data.f,
+          cell_state->data.f, input_gate_scratch, forget_gate_scratch,
+          cell_scratch, output_gate_scratch, output_ptr_time);
+    }
+  } else {
+    for (int b = 0; b < n_batch; b++) {
+      const int input_step = n_input;
+      const int output_step = output_batch_leading_dim;
+      for (int t = 0; t < max_time; t++) {
+        // If this is the forward_sequence, step forward, otherwise step
+        // backwards.
+        const int t_rel = forward_sequence ? t : max_time - t - 1;
+        // Index the flat input/output buffers at batch b, time t_rel.
+        const int time_offset = b * max_time + t_rel;
+        const float* input_ptr = input->data.f + time_offset * input_step;
+        float* output_ptr_time =
+            output->data.f + time_offset * output_step + output_offset;
+        LstmStepWithAuxInput(
+            input_ptr, input_to_input_weights_ptr,
+            input_to_forget_weights->data.f, input_to_cell_weights->data.f,
+            input_to_output_weights->data.f, aux_input_ptr,
+            aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
+            aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
+            recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
+            recurrent_to_cell_weights->data.f,
+            recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+            cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+            input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+            output_gate_bias->data.f, projection_weights_ptr,
+            projection_bias_ptr, params, /*n_batch=*/1, n_cell, n_input,
+            aux_input_size, n_output, output_batch_leading_dim,
+            activation_state->data.f, cell_state->data.f, input_gate_scratch,
+            forget_gate_scratch, cell_scratch, output_gate_scratch,
+            output_ptr_time);
+      }
+    }
   }
   return kTfLiteOk;
 }
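The two loop shapes above differ for a reason: at a fixed time step, time-major storage keeps the rows for all batches contiguous, so one LstmStepWithAuxInput call can cover the whole batch, while in batch-major storage the rows for a given step lie max_time * n_input elements apart, so each sequence is walked separately with n_batch set to 1. A worked layout, with illustrative sizes rather than values from the commit:

// n_batch = 2, max_time = 3; one cell per (time, batch) row:
//   time-major flat order : t0b0 t0b1 | t1b0 t1b1 | t2b0 t2b1
//     step t starts at t * n_batch * n_input            -> one call per t
//   batch-major flat order: b0t0 b0t1 b0t2 | b1t0 b1t1 b1t2
//     step (b, t) starts at (b * max_time + t) * n_input -> one call per (b, t)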
@@ -830,13 +866,13 @@ TfLiteStatus EvalHybrid(
     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
     const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
-    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
-    TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
-    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
-    TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
-    TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
-    TfLiteTensor* output_state, TfLiteTensor* cell_state,
-    TfLiteTensor* output) {
+    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
+    int output_offset, TfLiteTensor* scratch_buffer,
+    TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
+    TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
+    TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
+    TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
+    TfLiteTensor* cell_state, TfLiteTensor* output) {
   TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
   const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
   const int n_batch = input->dims->data[input->dims->size - 2];
@@ -990,45 +1026,90 @@ TfLiteStatus EvalHybrid(
         aux_input_to_output_weights->params.scale;
   }
 
-  // Feed the sequence into the LSTM step-by-step.
   const int output_batch_leading_dim =
       output->dims->data[output->dims->size - 1];
-  const int input_step = n_batch * n_input;
-  const int output_step = n_batch * output_batch_leading_dim;
-  for (int t = 0; t < max_time; t++) {
-    // If this is the forward_sequence, step forward, otherwise step backwards.
-    const int t_rel = forward_sequence ? t : max_time - t - 1;
-    const float* input_ptr = input->data.f + t_rel * input_step;
-    if (aux_input) {
-      aux_input_ptr = aux_input->data.f + t_rel * input_step;
-    }
-    float* output_ptr = output->data.f + t_rel * output_step + output_offset;
-
-    LstmStepWithAuxInput(
-        input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
-        input_to_forget_weights_ptr, input_to_forget_weights_scale,
-        input_to_cell_weights_ptr, input_to_cell_weights_scale,
-        input_to_output_weights_ptr, input_to_output_weights_scale,
-        aux_input_ptr, aux_input_to_input_weights_ptr,
-        aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
-        aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
-        aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
-        aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
-        recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
-        recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
-        recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
-        recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
-        cell_to_input_weights_scale, cell_to_forget_weights_ptr,
-        cell_to_forget_weights_scale, cell_to_output_weights_ptr,
-        cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
-        cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
-        projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
-        n_input, aux_input_size, n_output, output_batch_leading_dim,
-        input_gate_scratch, forget_gate_scratch, cell_scratch,
-        output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
-        recovered_cell_weights_ptr, quantized_input_ptr,
-        quantized_aux_input_ptr, quantized_output_state_ptr,
-        quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr);
+  if (time_major) {
+    // Feed the sequence into the LSTM step-by-step.
+    const int input_step = n_batch * n_input;
+    const int output_step = n_batch * output_batch_leading_dim;
+    for (int t = 0; t < max_time; t++) {
+      // If this is the forward_sequence, step forward, otherwise step
+      // backwards.
+      const int t_rel = forward_sequence ? t : max_time - t - 1;
+      const float* input_ptr = input->data.f + t_rel * input_step;
+      if (aux_input) {
+        aux_input_ptr = aux_input->data.f + t_rel * input_step;
+      }
+      float* output_ptr = output->data.f + t_rel * output_step + output_offset;
+
+      LstmStepWithAuxInput(
+          input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+          input_to_forget_weights_ptr, input_to_forget_weights_scale,
+          input_to_cell_weights_ptr, input_to_cell_weights_scale,
+          input_to_output_weights_ptr, input_to_output_weights_scale,
+          aux_input_ptr, aux_input_to_input_weights_ptr,
+          aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
+          aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
+          aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
+          aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
+          recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
+          recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
+          recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
+          recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
+          cell_to_input_weights_scale, cell_to_forget_weights_ptr,
+          cell_to_forget_weights_scale, cell_to_output_weights_ptr,
+          cell_to_output_weights_scale, input_gate_bias_ptr,
+          forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
+          projection_weights_ptr, projection_weights_scale, projection_bias_ptr,
+          params, n_batch, n_cell, n_input, aux_input_size, n_output,
+          output_batch_leading_dim, input_gate_scratch, forget_gate_scratch,
+          cell_scratch, output_gate_scratch, scaling_factors_ptr,
+          prod_scaling_factors_ptr, recovered_cell_weights_ptr,
+          quantized_input_ptr, quantized_aux_input_ptr,
+          quantized_output_state_ptr, quantized_cell_state_ptr,
+          output_state_ptr, cell_state_ptr, output_ptr);
+    }
+  } else {
+    for (int b = 0; b < n_batch; b++) {
+      const int input_step = n_input;
+      const int output_step = output_batch_leading_dim;
+      for (int t = 0; t < max_time; t++) {
+        // If this is the forward_sequence, step forward, otherwise step
+        // backwards.
+        const int t_rel = forward_sequence ? t : max_time - t - 1;
+        // Index the flat input/output buffers at batch b, time t_rel.
+        const int time_offset = b * max_time + t_rel;
+        const float* input_ptr = input->data.f + time_offset * input_step;
+        float* output_ptr =
+            output->data.f + time_offset * output_step + output_offset;
+        LstmStepWithAuxInput(
+            input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+            input_to_forget_weights_ptr, input_to_forget_weights_scale,
+            input_to_cell_weights_ptr, input_to_cell_weights_scale,
+            input_to_output_weights_ptr, input_to_output_weights_scale,
+            aux_input_ptr, aux_input_to_input_weights_ptr,
+            aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
+            aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
+            aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
+            aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
+            recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
+            recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
+            recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
+            recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
+            cell_to_input_weights_scale, cell_to_forget_weights_ptr,
+            cell_to_forget_weights_scale, cell_to_output_weights_ptr,
+            cell_to_output_weights_scale, input_gate_bias_ptr,
+            forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
+            projection_weights_ptr, projection_weights_scale,
+            projection_bias_ptr, params, /*n_batch=*/1, n_cell, n_input,
+            aux_input_size, n_output, output_batch_leading_dim,
+            input_gate_scratch, forget_gate_scratch, cell_scratch,
+            output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
+            recovered_cell_weights_ptr, quantized_input_ptr,
+            quantized_aux_input_ptr, quantized_output_state_ptr,
+            quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
+            output_ptr);
+      }
+    }
   }
   return kTfLiteOk;

lstm_eval.h

@@ -42,9 +42,10 @@ TfLiteStatus EvalFloat(
     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
     const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
-    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
-    TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
-    TfLiteTensor* cell_state, TfLiteTensor* output);
+    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
+    int output_offset, TfLiteTensor* scratch_buffer,
+    TfLiteTensor* activation_state, TfLiteTensor* cell_state,
+    TfLiteTensor* output);
 
 TfLiteStatus EvalHybrid(
     const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
@@ -65,12 +66,13 @@ TfLiteStatus EvalHybrid(
     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
     const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
-    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
-    TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
-    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
-    TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
-    TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
-    TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output);
+    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
+    int output_offset, TfLiteTensor* scratch_buffer,
+    TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
+    TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
+    TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
+    TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
+    TfLiteTensor* cell_state, TfLiteTensor* output);
 
 }  // namespace lstm_eval
 }  // namespace builtin

unidirectional_sequence_lstm.cc

@@ -260,8 +260,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
   TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
   TF_LITE_ENSURE(context, input->dims->size > 1);
-  const int max_time = input->dims->data[0];
-  const int n_batch = input->dims->data[1];
+  const auto* params =
+      reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
+          node->builtin_data);
+  const bool time_major = params->time_major;
+  const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
   const int n_input = input->dims->data[2];
 
   const TfLiteTensor* input_to_output_weights =
@@ -296,10 +299,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
 
   // Resize the output tensors.
-  TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
-  output_size->data[0] = max_time;
-  output_size->data[1] = n_batch;
-  output_size->data[2] = n_output;
+  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
+  output_size->data[input->dims->size - 1] = n_output;
   TF_LITE_ENSURE_OK(context,
                     context->ResizeTensor(context, output, output_size));
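The copy-based resize above keeps the input's layout and only replaces the innermost dimension with n_output, so it is correct for either layout. A worked example with illustrative sizes:

// time-major:  input [max_time=3, n_batch=2, n_input=5] -> output [3, 2, n_output]
// batch-major: input [n_batch=2, max_time=3, n_input=5] -> output [2, 3, n_output]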
@@ -436,6 +437,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const auto* params =
       reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
           node->builtin_data);
+  const bool time_major = params->time_major;
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
 
   const TfLiteTensor* input_to_input_weights =
@@ -506,7 +508,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           /*aux_input_to_cell_weights=*/nullptr,
           /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
           forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
-          projection_bias, &lstm_params, /*forward_sequence=*/true,
+          projection_bias, &lstm_params, /*forward_sequence=*/true, time_major,
           /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
           output);
     }
@@ -533,7 +535,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           /*aux_input_to_cell_weights=*/nullptr,
           /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
           forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
-          projection_bias, &lstm_params, /*forward_sequence=*/true,
+          projection_bias, &lstm_params, /*forward_sequence=*/true, time_major,
           /*output_offset=*/0, scratch_buffer, scaling_factors,
           prod_scaling_factors, recovered_cell_weights, input_quantized,
           /*aux_input_quantized=*/nullptr, activation_state_quantized,

unidirectional_sequence_lstm_test.cc

@@ -32,7 +32,7 @@ using ::testing::ElementsAreArray;
 class UnidirectionalLSTMOpModel : public SingleOpModel {
  public:
   UnidirectionalLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
-                            int sequence_length, bool use_cifg,
+                            int sequence_length, bool time_major, bool use_cifg,
                             bool use_peephole, bool use_projection_weights,
                             bool use_projection_bias, float cell_clip,
                             float proj_clip,
@@ -110,12 +110,12 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
     output_ = AddOutput(TensorType_FLOAT32);
 
-    SetBuiltinOp(
-        BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
-        BuiltinOptions_UnidirectionalSequenceLSTMOptions,
-        CreateUnidirectionalSequenceLSTMOptions(
-            builder_, ActivationFunctionType_TANH, cell_clip, proj_clip)
-            .Union());
+    SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
+                 BuiltinOptions_UnidirectionalSequenceLSTMOptions,
+                 CreateUnidirectionalSequenceLSTMOptions(
+                     builder_, ActivationFunctionType_TANH, cell_clip,
+                     proj_clip, time_major)
+                     .Union());
     BuildInterpreter(input_shapes);
   }
@@ -241,12 +241,12 @@ class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
  public:
   HybridUnidirectionalLSTMOpModel(
       int n_batch, int n_input, int n_cell, int n_output, int sequence_length,
-      bool use_cifg, bool use_peephole, bool use_projection_weights,
-      bool use_projection_bias, float cell_clip, float proj_clip,
-      const std::vector<std::vector<int>>& input_shapes)
+      bool time_major, bool use_cifg, bool use_peephole,
+      bool use_projection_weights, bool use_projection_bias, float cell_clip,
+      float proj_clip, const std::vector<std::vector<int>>& input_shapes)
       : UnidirectionalLSTMOpModel(
-            n_batch, n_input, n_cell, n_output, sequence_length, use_cifg,
-            use_peephole, use_projection_weights, use_projection_bias,
+            n_batch, n_input, n_cell, n_output, sequence_length, time_major,
+            use_cifg, use_peephole, use_projection_weights, use_projection_bias,
             cell_clip, proj_clip, input_shapes, TensorType_UINT8) {}
 
   void SetInputToInputWeights(const std::vector<float>& f) {
@@ -326,21 +326,32 @@ class BaseLstmTest : public ::testing::Test {
   // Compares output up to tolerance to the result of the lstm given the input.
   void VerifyGoldens(const std::vector<std::vector<float>>& input,
                      const std::vector<std::vector<float>>& output,
-                     UnidirectionalLSTMOpModel* lstm, float tolerance = 1e-5) {
+                     UnidirectionalLSTMOpModel* lstm, float tolerance = 1e-5,
+                     bool time_major = true) {
     const int num_batches = input.size();
     EXPECT_GT(num_batches, 0);
     const int num_inputs = lstm->num_inputs();
     EXPECT_GT(num_inputs, 0);
     const int input_sequence_size = input[0].size() / num_inputs;
     EXPECT_GT(input_sequence_size, 0);
-    // Feed the whole sequence as input.
-    for (int i = 0; i < input_sequence_size; ++i) {
-      for (int b = 0; b < num_batches; ++b) {
-        const float* batch_start = input[b].data() + i * num_inputs;
-        const float* batch_end = batch_start + num_inputs;
-
-        lstm->SetInput(((i * num_batches) + b) * lstm->num_inputs(),
-                       batch_start, batch_end);
+    if (time_major) {
+      // Feed the whole sequence as input.
+      for (int i = 0; i < input_sequence_size; ++i) {
+        for (int b = 0; b < num_batches; ++b) {
+          const float* batch_start = input[b].data() + i * num_inputs;
+          const float* batch_end = batch_start + num_inputs;
+
+          lstm->SetInput(((i * num_batches) + b) * num_inputs, batch_start,
+                         batch_end);
+        }
+      }
+    } else {
+      for (int b = 0; b < num_batches; ++b) {
+        const float* batch_start = input[b].data();
+        const float* batch_end = batch_start + input_sequence_size * num_inputs;
+
+        lstm->SetInput(b * input_sequence_size * num_inputs, batch_start,
+                       batch_end);
       }
     }
@@ -349,15 +360,25 @@ class BaseLstmTest : public ::testing::Test {
     const int num_outputs = lstm->num_outputs();
     EXPECT_GT(num_outputs, 0);
     std::vector<float> expected;
-    for (int i = 0; i < input_sequence_size; ++i) {
-      for (int b = 0; b < num_batches; ++b) {
-        const float* golden_start_batch = output[b].data() + i * num_outputs;
-        const float* golden_end_batch = golden_start_batch + num_outputs;
-        expected.insert(expected.end(), golden_start_batch, golden_end_batch);
+    if (time_major) {
+      for (int i = 0; i < input_sequence_size; ++i) {
+        for (int b = 0; b < num_batches; ++b) {
+          const float* golden_start_batch = output[b].data() + i * num_outputs;
+          const float* golden_end_batch = golden_start_batch + num_outputs;
+          expected.insert(expected.end(), golden_start_batch, golden_end_batch);
+        }
+      }
+    } else {
+      for (int b = 0; b < num_batches; ++b) {
+        const float* golden_batch_start = output[b].data();
+        const float* golden_batch_end =
+            golden_batch_start + input_sequence_size * num_outputs;
+        expected.insert(expected.end(), golden_batch_start, golden_batch_end);
      }
    }
     EXPECT_THAT(lstm->GetOutput(),
                 ElementsAreArray(ArrayFloatNear(expected, tolerance)));
   }
@@ -422,7 +443,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
   UnidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length,
-      /*use_cifg=*/false, /*use_peephole=*/false,
+      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/false,
       /*use_projection_weights=*/false,
       /*use_projection_bias=*/false,
       /*cell_clip=*/0.0, /*proj_clip=*/0.0,
@@ -473,6 +494,73 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
 }
 
+TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
+       LstmBlackBoxTestBatchMajor) {
+  const int n_batch = 1;
+  const int n_input = 2;
+  // n_cell and n_output have the same size when there is no projection.
+  const int n_cell = 4;
+  const int n_output = 4;
+  const int sequence_length = 3;
+
+  UnidirectionalLSTMOpModel lstm(
+      n_batch, n_input, n_cell, n_output, sequence_length,
+      /*time_major=*/false, /*use_cifg=*/false, /*use_peephole=*/false,
+      /*use_projection_weights=*/false,
+      /*use_projection_bias=*/false,
+      /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+      {
+          {n_batch, sequence_length, n_input},  // input tensor
+          {n_cell, n_input},                    // input_to_input_weight tensor
+          {n_cell, n_input},                    // input_to_forget_weight tensor
+          {n_cell, n_input},                    // input_to_cell_weight tensor
+          {n_cell, n_input},                    // input_to_output_weight tensor
+          {n_cell, n_output},  // recurrent_to_input_weight tensor
+          {n_cell, n_output},  // recurrent_to_forget_weight tensor
+          {n_cell, n_output},  // recurrent_to_cell_weight tensor
+          {n_cell, n_output},  // recurrent_to_output_weight tensor
+          {0},                 // cell_to_input_weight tensor
+          {0},                 // cell_to_forget_weight tensor
+          {0},                 // cell_to_output_weight tensor
+          {n_cell},            // input_gate_bias tensor
+          {n_cell},            // forget_gate_bias tensor
+          {n_cell},            // cell_bias tensor
+          {n_cell},            // output_gate_bias tensor
+          {0, 0},              // projection_weight tensor
+          {0},                 // projection_bias tensor
+          {n_batch, n_output},  // activation_state tensor
+          {n_batch, n_cell},    // cell_state tensor
+      });
+
+  lstm.SetInputToInputWeights(input_to_input_weights_);
+  lstm.SetInputToCellWeights(input_to_cell_weights_);
+  lstm.SetInputToForgetWeights(input_to_forget_weights_);
+  lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+  lstm.SetInputGateBias(input_gate_bias_);
+  lstm.SetCellBias(cell_gate_bias_);
+  lstm.SetForgetGateBias(forget_gate_bias_);
+  lstm.SetOutputGateBias(output_gate_bias_);
+
+  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+  // With n_batch == 1 the sequence data is laid out identically in
+  // time-major and batch-major formats, so the time-major goldens can be
+  // fed to the batch-major model unchanged.
+  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/1e-5,
+                /*time_major=*/false);
+}
+
 TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
   const int n_batch = 1;
   const int n_input = 2;
@@ -483,7 +571,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
   HybridUnidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length,
-      /*use_cifg=*/false, /*use_peephole=*/false,
+      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/false,
       /*use_projection_weights=*/false,
       /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
       {
@@ -591,7 +679,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
   UnidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length,
-      /*use_cifg=*/true, /*use_peephole=*/true,
+      /*time_major=*/true, /*use_cifg=*/true, /*use_peephole=*/true,
       /*use_projection_weights=*/false,
       /*use_projection_bias=*/false,
       /*cell_clip=*/0.0, /*proj_clip=*/0.0,
@@ -652,7 +740,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
   HybridUnidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length,
-      /*use_cifg=*/true, /*use_peephole=*/true,
+      /*time_major=*/true, /*use_cifg=*/true, /*use_peephole=*/true,
       /*use_projection_weights=*/false,
       /*use_projection_bias=*/false,
       /*cell_clip=*/0.0, /*proj_clip=*/0.0,
@@ -1311,7 +1399,7 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
   UnidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length,
-      /*use_cifg=*/false, /*use_peephole=*/true,
+      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true,
       /*use_projection_weights=*/true,
       /*use_projection_bias=*/false,
       /*cell_clip=*/0.0, /*proj_clip=*/0.0,
@@ -1377,7 +1465,7 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
   HybridUnidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length,
-      /*use_cifg=*/false, /*use_peephole=*/true,
+      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true,
       /*use_projection_weights=*/true,
       /*use_projection_bias=*/false,
       /*cell_clip=*/0.0, /*proj_clip=*/0.0,

schema.fbs

@@ -407,6 +407,9 @@ table UnidirectionalSequenceLSTMOptions {
   fused_activation_function:ActivationFunctionType;
   cell_clip: float;  // Optional, 0.0 means no clipping
   proj_clip: float;  // Optional, 0.0 means no clipping
+
+  // If true then first dimension is sequence, otherwise batch.
+  time_major:bool;
 }
 
 table BidirectionalSequenceLSTMOptions {

schema_generated.h

@@ -3534,10 +3534,12 @@ struct UnidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
   ActivationFunctionType fused_activation_function;
   float cell_clip;
   float proj_clip;
+  bool time_major;
   UnidirectionalSequenceLSTMOptionsT()
       : fused_activation_function(ActivationFunctionType_NONE),
         cell_clip(0.0f),
-        proj_clip(0.0f) {
+        proj_clip(0.0f),
+        time_major(false) {
   }
 };
@@ -3546,7 +3548,8 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb
   enum {
     VT_FUSED_ACTIVATION_FUNCTION = 4,
     VT_CELL_CLIP = 6,
-    VT_PROJ_CLIP = 8
+    VT_PROJ_CLIP = 8,
+    VT_TIME_MAJOR = 10
   };
   ActivationFunctionType fused_activation_function() const {
     return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
@@ -3557,11 +3560,15 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb
   float proj_clip() const {
     return GetField<float>(VT_PROJ_CLIP, 0.0f);
   }
+  bool time_major() const {
+    return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
+  }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
            VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
            VerifyField<float>(verifier, VT_CELL_CLIP) &&
            VerifyField<float>(verifier, VT_PROJ_CLIP) &&
+           VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
           verifier.EndTable();
   }
   UnidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -3581,6 +3588,9 @@ struct UnidirectionalSequenceLSTMOptionsBuilder {
   void add_proj_clip(float proj_clip) {
     fbb_.AddElement<float>(UnidirectionalSequenceLSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f);
   }
+  void add_time_major(bool time_major) {
+    fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast<uint8_t>(time_major), 0);
+  }
   explicit UnidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
       : fbb_(_fbb) {
     start_ = fbb_.StartTable();
@@ -3597,10 +3607,12 @@ inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirection
     flatbuffers::FlatBufferBuilder &_fbb,
     ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
     float cell_clip = 0.0f,
-    float proj_clip = 0.0f) {
+    float proj_clip = 0.0f,
+    bool time_major = false) {
   UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
   builder_.add_proj_clip(proj_clip);
   builder_.add_cell_clip(cell_clip);
+  builder_.add_time_major(time_major);
   builder_.add_fused_activation_function(fused_activation_function);
   return builder_.Finish();
 }
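A minimal builder-side sketch of setting the new field (standalone usage, not from this commit; it assumes only the generated API shown above):

flatbuffers::FlatBufferBuilder fbb;
// time_major defaults to false (batch-major); pass true to keep the
// classic [time, batch, input] layout.
auto options = tflite::CreateUnidirectionalSequenceLSTMOptions(
    fbb, tflite::ActivationFunctionType_TANH,
    /*cell_clip=*/0.0f, /*proj_clip=*/0.0f, /*time_major=*/true);
fbb.Finish(options);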
@@ -8060,6 +8072,7 @@ inline void UnidirectionalSequenceLSTMOptions::UnPackTo(UnidirectionalSequenceLS
   { auto _e = fused_activation_function(); _o->fused_activation_function = _e; };
   { auto _e = cell_clip(); _o->cell_clip = _e; };
   { auto _e = proj_clip(); _o->proj_clip = _e; };
+  { auto _e = time_major(); _o->time_major = _e; };
 }
 
 inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> UnidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -8073,11 +8086,13 @@ inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirection
   auto _fused_activation_function = _o->fused_activation_function;
   auto _cell_clip = _o->cell_clip;
   auto _proj_clip = _o->proj_clip;
+  auto _time_major = _o->time_major;
   return tflite::CreateUnidirectionalSequenceLSTMOptions(
       _fbb,
       _fused_activation_function,
       _cell_clip,
-      _proj_clip);
+      _proj_clip,
+      _time_major);
 }
 
 inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {