Add support for batch-major input in the unidirectional LSTM Op.
PiperOrigin-RevId: 217406579
This commit is contained in:
parent
e6440a80c8
commit
a3f855aca2
@ -187,10 +187,13 @@ typedef struct {
|
|||||||
} TfLiteLSTMParams;
|
} TfLiteLSTMParams;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
// Parameters for the LSTM kernel.
|
// Parameters needed for the underlying LSTM.
|
||||||
TfLiteFusedActivation activation;
|
TfLiteFusedActivation activation;
|
||||||
float cell_clip;
|
float cell_clip;
|
||||||
float proj_clip;
|
float proj_clip;
|
||||||
|
|
||||||
|
// If set to true then the first dimension is time, otherwise batch.
|
||||||
|
bool time_major;
|
||||||
} TfLiteUnidirectionalSequenceLSTMParams;
|
} TfLiteUnidirectionalSequenceLSTMParams;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
|
@ -399,11 +399,11 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
|||||||
parse_activation(seq_lstm_params->fused_activation_function());
|
parse_activation(seq_lstm_params->fused_activation_function());
|
||||||
params->cell_clip = seq_lstm_params->cell_clip();
|
params->cell_clip = seq_lstm_params->cell_clip();
|
||||||
params->proj_clip = seq_lstm_params->proj_clip();
|
params->proj_clip = seq_lstm_params->proj_clip();
|
||||||
|
params->time_major = seq_lstm_params->time_major();
|
||||||
}
|
}
|
||||||
*builtin_data = reinterpret_cast<void*>(params);
|
*builtin_data = reinterpret_cast<void*>(params);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: {
|
case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: {
|
||||||
auto params =
|
auto params =
|
||||||
allocator->AllocatePOD<TfLiteBidirectionalSequenceLSTMParams>();
|
allocator->AllocatePOD<TfLiteBidirectionalSequenceLSTMParams>();
|
||||||
|
@ -876,6 +876,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0;
|
params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0;
|
||||||
const auto actual_bw_output = params->merge_outputs ? fw_output : bw_output;
|
const auto actual_bw_output = params->merge_outputs ? fw_output : bw_output;
|
||||||
|
|
||||||
|
// TODO(mirkov): add batch_major support (http://b/117326122).
|
||||||
switch (fw_input_to_output_weights->type) {
|
switch (fw_input_to_output_weights->type) {
|
||||||
case kTfLiteFloat32: {
|
case kTfLiteFloat32: {
|
||||||
TfLiteStatus fw_pass_status = lstm_eval::EvalFloat(
|
TfLiteStatus fw_pass_status = lstm_eval::EvalFloat(
|
||||||
@ -889,8 +890,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
fw_aux_input_to_output_weights, fw_input_gate_bias,
|
fw_aux_input_to_output_weights, fw_input_gate_bias,
|
||||||
fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
|
fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
|
||||||
fw_projection_weights, fw_projection_bias, &lstm_params,
|
fw_projection_weights, fw_projection_bias, &lstm_params,
|
||||||
/*forward_sequence=*/true, /*output_offset=*/0, fw_scratch_buffer,
|
/*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
|
||||||
fw_activation_state, fw_cell_state, fw_output);
|
fw_scratch_buffer, fw_activation_state, fw_cell_state, fw_output);
|
||||||
TF_LITE_ENSURE_OK(context, fw_pass_status);
|
TF_LITE_ENSURE_OK(context, fw_pass_status);
|
||||||
|
|
||||||
TfLiteStatus bw_pass_status = lstm_eval::EvalFloat(
|
TfLiteStatus bw_pass_status = lstm_eval::EvalFloat(
|
||||||
@ -904,8 +905,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
bw_aux_input_to_output_weights, bw_input_gate_bias,
|
bw_aux_input_to_output_weights, bw_input_gate_bias,
|
||||||
bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
|
bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
|
||||||
bw_projection_weights, bw_projection_bias, &lstm_params,
|
bw_projection_weights, bw_projection_bias, &lstm_params,
|
||||||
/*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer,
|
/*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,
|
||||||
bw_activation_state, bw_cell_state, actual_bw_output);
|
bw_scratch_buffer, bw_activation_state, bw_cell_state,
|
||||||
|
actual_bw_output);
|
||||||
TF_LITE_ENSURE_OK(context, bw_pass_status);
|
TF_LITE_ENSURE_OK(context, bw_pass_status);
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
@ -942,11 +944,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
fw_aux_input_to_output_weights, fw_input_gate_bias,
|
fw_aux_input_to_output_weights, fw_input_gate_bias,
|
||||||
fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
|
fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
|
||||||
fw_projection_weights, fw_projection_bias, &lstm_params,
|
fw_projection_weights, fw_projection_bias, &lstm_params,
|
||||||
/*forward_sequence=*/true, /*output_offset=*/0, fw_scratch_buffer,
|
/*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
|
||||||
scaling_factors, prod_scaling_factors, recovered_cell_weights,
|
fw_scratch_buffer, scaling_factors, prod_scaling_factors,
|
||||||
input_quantized, aux_input_quantized, fw_activation_state_quantized,
|
recovered_cell_weights, input_quantized, aux_input_quantized,
|
||||||
fw_cell_state_quantized, fw_activation_state, fw_cell_state,
|
fw_activation_state_quantized, fw_cell_state_quantized,
|
||||||
fw_output);
|
fw_activation_state, fw_cell_state, fw_output);
|
||||||
TF_LITE_ENSURE_OK(context, fw_pass_status);
|
TF_LITE_ENSURE_OK(context, fw_pass_status);
|
||||||
|
|
||||||
TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
|
TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
|
||||||
@ -960,11 +962,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
bw_aux_input_to_output_weights, bw_input_gate_bias,
|
bw_aux_input_to_output_weights, bw_input_gate_bias,
|
||||||
bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
|
bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
|
||||||
bw_projection_weights, bw_projection_bias, &lstm_params,
|
bw_projection_weights, bw_projection_bias, &lstm_params,
|
||||||
/*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer,
|
/*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,
|
||||||
scaling_factors, prod_scaling_factors, recovered_cell_weights,
|
bw_scratch_buffer, scaling_factors, prod_scaling_factors,
|
||||||
input_quantized, aux_input_quantized, bw_activation_state_quantized,
|
recovered_cell_weights, input_quantized, aux_input_quantized,
|
||||||
bw_cell_state_quantized, bw_activation_state, bw_cell_state,
|
bw_activation_state_quantized, bw_cell_state_quantized,
|
||||||
actual_bw_output);
|
bw_activation_state, bw_cell_state, actual_bw_output);
|
||||||
TF_LITE_ENSURE_OK(context, bw_pass_status);
|
TF_LITE_ENSURE_OK(context, bw_pass_status);
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
@ -497,6 +497,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
|
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
|
||||||
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
|
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
|
||||||
projection_bias, params, /*forward_sequence=*/true,
|
projection_bias, params, /*forward_sequence=*/true,
|
||||||
|
/*time_major=*/true,
|
||||||
/*output_offset=*/0, scratch_buffer, activation_state, cell_state,
|
/*output_offset=*/0, scratch_buffer, activation_state, cell_state,
|
||||||
output);
|
output);
|
||||||
}
|
}
|
||||||
@ -524,8 +525,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
|
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
|
||||||
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
|
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
|
||||||
projection_bias, params, /*forward_sequence=*/true,
|
projection_bias, params, /*forward_sequence=*/true,
|
||||||
/*output_offset=*/0, scratch_buffer, scaling_factors,
|
/*time_major=*/true, /*output_offset=*/0, scratch_buffer,
|
||||||
prod_scaling_factors, recovered_cell_weights, input_quantized,
|
scaling_factors, prod_scaling_factors, recovered_cell_weights,
|
||||||
|
input_quantized,
|
||||||
/*aux_input_quantized=*/nullptr, activation_state_quantized,
|
/*aux_input_quantized=*/nullptr, activation_state_quantized,
|
||||||
cell_state_quantized, activation_state, cell_state, output);
|
cell_state_quantized, activation_state, cell_state, output);
|
||||||
}
|
}
|
||||||
|
@ -710,9 +710,10 @@ TfLiteStatus EvalFloat(
|
|||||||
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
|
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
|
||||||
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
|
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
|
||||||
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
||||||
const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
|
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
|
||||||
TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
|
int output_offset, TfLiteTensor* scratch_buffer,
|
||||||
TfLiteTensor* cell_state, TfLiteTensor* output) {
|
TfLiteTensor* activation_state, TfLiteTensor* cell_state,
|
||||||
|
TfLiteTensor* output) {
|
||||||
TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
|
TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
|
||||||
const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
|
const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
|
||||||
const int n_batch = input->dims->data[input->dims->size - 2];
|
const int n_batch = input->dims->data[input->dims->size - 2];
|
||||||
@ -777,36 +778,71 @@ TfLiteStatus EvalFloat(
|
|||||||
aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f;
|
aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Loop through the sequence.
|
|
||||||
const int output_batch_leading_dim =
|
const int output_batch_leading_dim =
|
||||||
output->dims->data[output->dims->size - 1];
|
output->dims->data[output->dims->size - 1];
|
||||||
const int input_step = n_batch * n_input;
|
if (time_major) {
|
||||||
const int output_step = n_batch * output_batch_leading_dim;
|
// Loop through the sequence.
|
||||||
for (int t = 0; t < max_time; t++) {
|
const int input_step = n_batch * n_input;
|
||||||
// If this is the forward_sequence, step forward, otherwise step backwards.
|
const int output_step = n_batch * output_batch_leading_dim;
|
||||||
const int t_rel = forward_sequence ? t : max_time - t - 1;
|
for (int t = 0; t < max_time; t++) {
|
||||||
const float* input_ptr = input->data.f + t_rel * input_step;
|
// If this is the forward_sequence, step forward, otherwise step
|
||||||
if (aux_input) {
|
// backwards.
|
||||||
aux_input_ptr = aux_input->data.f + t_rel * input_step;
|
const int t_rel = forward_sequence ? t : max_time - t - 1;
|
||||||
}
|
const float* input_ptr = input->data.f + t_rel * input_step;
|
||||||
float* output_ptr_time =
|
if (aux_input) {
|
||||||
output->data.f + t_rel * output_step + output_offset;
|
aux_input_ptr = aux_input->data.f + t_rel * input_step;
|
||||||
|
}
|
||||||
|
float* output_ptr_time =
|
||||||
|
output->data.f + t_rel * output_step + output_offset;
|
||||||
|
|
||||||
LstmStepWithAuxInput(
|
LstmStepWithAuxInput(
|
||||||
input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f,
|
input_ptr, input_to_input_weights_ptr,
|
||||||
input_to_cell_weights->data.f, input_to_output_weights->data.f,
|
input_to_forget_weights->data.f, input_to_cell_weights->data.f,
|
||||||
aux_input_ptr, aux_input_to_input_weights_ptr,
|
input_to_output_weights->data.f, aux_input_ptr,
|
||||||
aux_input_to_forget_weights_ptr, aux_input_to_cell_weights_ptr,
|
aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
|
||||||
aux_input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
|
aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
|
||||||
recurrent_to_forget_weights->data.f, recurrent_to_cell_weights->data.f,
|
recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
|
||||||
recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
|
recurrent_to_cell_weights->data.f,
|
||||||
cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
|
recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
|
||||||
input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
|
cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
|
||||||
output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
|
input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
|
||||||
params, n_batch, n_cell, n_input, aux_input_size, n_output,
|
output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
|
||||||
output_batch_leading_dim, activation_state->data.f, cell_state->data.f,
|
params, n_batch, n_cell, n_input, aux_input_size, n_output,
|
||||||
input_gate_scratch, forget_gate_scratch, cell_scratch,
|
output_batch_leading_dim, activation_state->data.f,
|
||||||
output_gate_scratch, output_ptr_time);
|
cell_state->data.f, input_gate_scratch, forget_gate_scratch,
|
||||||
|
cell_scratch, output_gate_scratch, output_ptr_time);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int b = 0; b < n_batch; b++) {
|
||||||
|
const int input_step = n_input;
|
||||||
|
const int output_step = output_batch_leading_dim;
|
||||||
|
for (int t = 0; t < max_time; t++) {
|
||||||
|
// If this is the forward_sequence, step forward, otherwise step
|
||||||
|
// backwards.
|
||||||
|
const int t_rel = forward_sequence ? t : max_time - t - 1;
|
||||||
|
const float* input_ptr = input->data.f + t_rel * input_step;
|
||||||
|
float* output_ptr_time =
|
||||||
|
output->data.f + t_rel * output_step + output_offset;
|
||||||
|
|
||||||
|
LstmStepWithAuxInput(
|
||||||
|
input_ptr, input_to_input_weights_ptr,
|
||||||
|
input_to_forget_weights->data.f, input_to_cell_weights->data.f,
|
||||||
|
input_to_output_weights->data.f, aux_input_ptr,
|
||||||
|
aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
|
||||||
|
aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
|
||||||
|
recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
|
||||||
|
recurrent_to_cell_weights->data.f,
|
||||||
|
recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
|
||||||
|
cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
|
||||||
|
input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
|
||||||
|
output_gate_bias->data.f, projection_weights_ptr,
|
||||||
|
projection_bias_ptr, params, /*n_batch=*/1, n_cell, n_input,
|
||||||
|
aux_input_size, n_output, output_batch_leading_dim,
|
||||||
|
activation_state->data.f, cell_state->data.f, input_gate_scratch,
|
||||||
|
forget_gate_scratch, cell_scratch, output_gate_scratch,
|
||||||
|
output_ptr_time);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
@ -830,13 +866,13 @@ TfLiteStatus EvalHybrid(
|
|||||||
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
|
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
|
||||||
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
|
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
|
||||||
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
||||||
const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
|
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
|
||||||
TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
|
int output_offset, TfLiteTensor* scratch_buffer,
|
||||||
TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
|
TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
|
||||||
TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
|
TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
|
||||||
TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
|
TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
|
||||||
TfLiteTensor* output_state, TfLiteTensor* cell_state,
|
TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
|
||||||
TfLiteTensor* output) {
|
TfLiteTensor* cell_state, TfLiteTensor* output) {
|
||||||
TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
|
TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
|
||||||
const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
|
const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
|
||||||
const int n_batch = input->dims->data[input->dims->size - 2];
|
const int n_batch = input->dims->data[input->dims->size - 2];
|
||||||
@ -990,45 +1026,90 @@ TfLiteStatus EvalHybrid(
|
|||||||
aux_input_to_output_weights->params.scale;
|
aux_input_to_output_weights->params.scale;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Feed the sequence into the LSTM step-by-step.
|
|
||||||
const int output_batch_leading_dim =
|
const int output_batch_leading_dim =
|
||||||
output->dims->data[output->dims->size - 1];
|
output->dims->data[output->dims->size - 1];
|
||||||
const int input_step = n_batch * n_input;
|
if (time_major) {
|
||||||
const int output_step = n_batch * output_batch_leading_dim;
|
// Feed the sequence into the LSTM step-by-step.
|
||||||
for (int t = 0; t < max_time; t++) {
|
const int input_step = n_batch * n_input;
|
||||||
// If this is the forward_sequence, step forward, otherwise step backwards.
|
const int output_step = n_batch * output_batch_leading_dim;
|
||||||
const int t_rel = forward_sequence ? t : max_time - t - 1;
|
for (int t = 0; t < max_time; t++) {
|
||||||
const float* input_ptr = input->data.f + t_rel * input_step;
|
// If this is the forward_sequence, step forward, otherwise step
|
||||||
if (aux_input) {
|
// backwards.
|
||||||
aux_input_ptr = aux_input->data.f + t_rel * input_step;
|
const int t_rel = forward_sequence ? t : max_time - t - 1;
|
||||||
}
|
const float* input_ptr = input->data.f + t_rel * input_step;
|
||||||
float* output_ptr = output->data.f + t_rel * output_step + output_offset;
|
if (aux_input) {
|
||||||
|
aux_input_ptr = aux_input->data.f + t_rel * input_step;
|
||||||
|
}
|
||||||
|
float* output_ptr = output->data.f + t_rel * output_step + output_offset;
|
||||||
|
|
||||||
LstmStepWithAuxInput(
|
LstmStepWithAuxInput(
|
||||||
input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
|
input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
|
||||||
input_to_forget_weights_ptr, input_to_forget_weights_scale,
|
input_to_forget_weights_ptr, input_to_forget_weights_scale,
|
||||||
input_to_cell_weights_ptr, input_to_cell_weights_scale,
|
input_to_cell_weights_ptr, input_to_cell_weights_scale,
|
||||||
input_to_output_weights_ptr, input_to_output_weights_scale,
|
input_to_output_weights_ptr, input_to_output_weights_scale,
|
||||||
aux_input_ptr, aux_input_to_input_weights_ptr,
|
aux_input_ptr, aux_input_to_input_weights_ptr,
|
||||||
aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
|
aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
|
||||||
aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
|
aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
|
||||||
aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
|
aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
|
||||||
aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
|
aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
|
||||||
recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
|
recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
|
||||||
recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
|
recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
|
||||||
recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
|
recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
|
||||||
recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
|
recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
|
||||||
cell_to_input_weights_scale, cell_to_forget_weights_ptr,
|
cell_to_input_weights_scale, cell_to_forget_weights_ptr,
|
||||||
cell_to_forget_weights_scale, cell_to_output_weights_ptr,
|
cell_to_forget_weights_scale, cell_to_output_weights_ptr,
|
||||||
cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
|
cell_to_output_weights_scale, input_gate_bias_ptr,
|
||||||
cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
|
forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
|
||||||
projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
|
projection_weights_ptr, projection_weights_scale, projection_bias_ptr,
|
||||||
n_input, aux_input_size, n_output, output_batch_leading_dim,
|
params, n_batch, n_cell, n_input, aux_input_size, n_output,
|
||||||
input_gate_scratch, forget_gate_scratch, cell_scratch,
|
output_batch_leading_dim, input_gate_scratch, forget_gate_scratch,
|
||||||
output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
|
cell_scratch, output_gate_scratch, scaling_factors_ptr,
|
||||||
recovered_cell_weights_ptr, quantized_input_ptr,
|
prod_scaling_factors_ptr, recovered_cell_weights_ptr,
|
||||||
quantized_aux_input_ptr, quantized_output_state_ptr,
|
quantized_input_ptr, quantized_aux_input_ptr,
|
||||||
quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr);
|
quantized_output_state_ptr, quantized_cell_state_ptr,
|
||||||
|
output_state_ptr, cell_state_ptr, output_ptr);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int b = 0; b < n_batch; b++) {
|
||||||
|
const int input_step = n_input;
|
||||||
|
const int output_step = output_batch_leading_dim;
|
||||||
|
for (int t = 0; t < max_time; t++) {
|
||||||
|
// If this is the forward_sequence, step forward, otherwise step
|
||||||
|
// backwards.
|
||||||
|
const int t_rel = forward_sequence ? t : max_time - t - 1;
|
||||||
|
const float* input_ptr = input->data.f + t_rel * input_step;
|
||||||
|
float* output_ptr =
|
||||||
|
output->data.f + t_rel * output_step + output_offset;
|
||||||
|
|
||||||
|
LstmStepWithAuxInput(
|
||||||
|
input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
|
||||||
|
input_to_forget_weights_ptr, input_to_forget_weights_scale,
|
||||||
|
input_to_cell_weights_ptr, input_to_cell_weights_scale,
|
||||||
|
input_to_output_weights_ptr, input_to_output_weights_scale,
|
||||||
|
aux_input_ptr, aux_input_to_input_weights_ptr,
|
||||||
|
aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
|
||||||
|
aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
|
||||||
|
aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
|
||||||
|
aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
|
||||||
|
recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
|
||||||
|
recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
|
||||||
|
recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
|
||||||
|
recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
|
||||||
|
cell_to_input_weights_scale, cell_to_forget_weights_ptr,
|
||||||
|
cell_to_forget_weights_scale, cell_to_output_weights_ptr,
|
||||||
|
cell_to_output_weights_scale, input_gate_bias_ptr,
|
||||||
|
forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
|
||||||
|
projection_weights_ptr, projection_weights_scale,
|
||||||
|
projection_bias_ptr, params, n_batch, n_cell, n_input,
|
||||||
|
aux_input_size, n_output, output_batch_leading_dim,
|
||||||
|
input_gate_scratch, forget_gate_scratch, cell_scratch,
|
||||||
|
output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
|
||||||
|
recovered_cell_weights_ptr, quantized_input_ptr,
|
||||||
|
quantized_aux_input_ptr, quantized_output_state_ptr,
|
||||||
|
quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
|
||||||
|
output_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
|
@ -42,9 +42,10 @@ TfLiteStatus EvalFloat(
|
|||||||
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
|
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
|
||||||
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
|
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
|
||||||
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
||||||
const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
|
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
|
||||||
TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
|
int output_offset, TfLiteTensor* scratch_buffer,
|
||||||
TfLiteTensor* cell_state, TfLiteTensor* output);
|
TfLiteTensor* activation_state, TfLiteTensor* cell_state,
|
||||||
|
TfLiteTensor* output);
|
||||||
|
|
||||||
TfLiteStatus EvalHybrid(
|
TfLiteStatus EvalHybrid(
|
||||||
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
|
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
|
||||||
@ -65,12 +66,13 @@ TfLiteStatus EvalHybrid(
|
|||||||
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
|
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
|
||||||
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
|
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
|
||||||
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
||||||
const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
|
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
|
||||||
TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
|
int output_offset, TfLiteTensor* scratch_buffer,
|
||||||
TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
|
TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
|
||||||
TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
|
TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
|
||||||
TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
|
TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
|
||||||
TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output);
|
TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
|
||||||
|
TfLiteTensor* cell_state, TfLiteTensor* output);
|
||||||
|
|
||||||
} // namespace lstm_eval
|
} // namespace lstm_eval
|
||||||
} // namespace builtin
|
} // namespace builtin
|
||||||
|
@ -260,8 +260,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||||
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
|
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
|
||||||
TF_LITE_ENSURE(context, input->dims->size > 1);
|
TF_LITE_ENSURE(context, input->dims->size > 1);
|
||||||
const int max_time = input->dims->data[0];
|
const auto* params =
|
||||||
const int n_batch = input->dims->data[1];
|
reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
|
||||||
|
node->builtin_data);
|
||||||
|
const bool time_major = params->time_major;
|
||||||
|
const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
|
||||||
const int n_input = input->dims->data[2];
|
const int n_input = input->dims->data[2];
|
||||||
|
|
||||||
const TfLiteTensor* input_to_output_weights =
|
const TfLiteTensor* input_to_output_weights =
|
||||||
@ -296,10 +299,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
|
TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
|
||||||
|
|
||||||
// Resize the output tensors.
|
// Resize the output tensors.
|
||||||
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
|
TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
|
||||||
output_size->data[0] = max_time;
|
output_size->data[input->dims->size - 1] = n_output;
|
||||||
output_size->data[1] = n_batch;
|
|
||||||
output_size->data[2] = n_output;
|
|
||||||
TF_LITE_ENSURE_OK(context,
|
TF_LITE_ENSURE_OK(context,
|
||||||
context->ResizeTensor(context, output, output_size));
|
context->ResizeTensor(context, output, output_size));
|
||||||
|
|
||||||
@ -436,6 +437,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
const auto* params =
|
const auto* params =
|
||||||
reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
|
reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
|
||||||
node->builtin_data);
|
node->builtin_data);
|
||||||
|
const bool time_major = params->time_major;
|
||||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||||
|
|
||||||
const TfLiteTensor* input_to_input_weights =
|
const TfLiteTensor* input_to_input_weights =
|
||||||
@ -506,7 +508,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
/*aux_input_to_cell_weights=*/nullptr,
|
/*aux_input_to_cell_weights=*/nullptr,
|
||||||
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
|
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
|
||||||
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
|
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
|
||||||
projection_bias, &lstm_params, /*forward_sequence=*/true,
|
projection_bias, &lstm_params, /*forward_sequence=*/true, time_major,
|
||||||
/*output_offset=*/0, scratch_buffer, activation_state, cell_state,
|
/*output_offset=*/0, scratch_buffer, activation_state, cell_state,
|
||||||
output);
|
output);
|
||||||
}
|
}
|
||||||
@ -533,7 +535,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
/*aux_input_to_cell_weights=*/nullptr,
|
/*aux_input_to_cell_weights=*/nullptr,
|
||||||
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
|
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
|
||||||
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
|
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
|
||||||
projection_bias, &lstm_params, /*forward_sequence=*/true,
|
projection_bias, &lstm_params, /*forward_sequence=*/true, time_major,
|
||||||
/*output_offset=*/0, scratch_buffer, scaling_factors,
|
/*output_offset=*/0, scratch_buffer, scaling_factors,
|
||||||
prod_scaling_factors, recovered_cell_weights, input_quantized,
|
prod_scaling_factors, recovered_cell_weights, input_quantized,
|
||||||
/*aux_input_quantized=*/nullptr, activation_state_quantized,
|
/*aux_input_quantized=*/nullptr, activation_state_quantized,
|
||||||
|
@ -32,7 +32,7 @@ using ::testing::ElementsAreArray;
|
|||||||
class UnidirectionalLSTMOpModel : public SingleOpModel {
|
class UnidirectionalLSTMOpModel : public SingleOpModel {
|
||||||
public:
|
public:
|
||||||
UnidirectionalLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
|
UnidirectionalLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
|
||||||
int sequence_length, bool use_cifg,
|
int sequence_length, bool time_major, bool use_cifg,
|
||||||
bool use_peephole, bool use_projection_weights,
|
bool use_peephole, bool use_projection_weights,
|
||||||
bool use_projection_bias, float cell_clip,
|
bool use_projection_bias, float cell_clip,
|
||||||
float proj_clip,
|
float proj_clip,
|
||||||
@ -110,12 +110,12 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
|
|||||||
|
|
||||||
output_ = AddOutput(TensorType_FLOAT32);
|
output_ = AddOutput(TensorType_FLOAT32);
|
||||||
|
|
||||||
SetBuiltinOp(
|
SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
|
||||||
BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
|
BuiltinOptions_UnidirectionalSequenceLSTMOptions,
|
||||||
BuiltinOptions_UnidirectionalSequenceLSTMOptions,
|
CreateUnidirectionalSequenceLSTMOptions(
|
||||||
CreateUnidirectionalSequenceLSTMOptions(
|
builder_, ActivationFunctionType_TANH, cell_clip,
|
||||||
builder_, ActivationFunctionType_TANH, cell_clip, proj_clip)
|
proj_clip, time_major)
|
||||||
.Union());
|
.Union());
|
||||||
BuildInterpreter(input_shapes);
|
BuildInterpreter(input_shapes);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -241,12 +241,12 @@ class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
|
|||||||
public:
|
public:
|
||||||
HybridUnidirectionalLSTMOpModel(
|
HybridUnidirectionalLSTMOpModel(
|
||||||
int n_batch, int n_input, int n_cell, int n_output, int sequence_length,
|
int n_batch, int n_input, int n_cell, int n_output, int sequence_length,
|
||||||
bool use_cifg, bool use_peephole, bool use_projection_weights,
|
bool time_major, bool use_cifg, bool use_peephole,
|
||||||
bool use_projection_bias, float cell_clip, float proj_clip,
|
bool use_projection_weights, bool use_projection_bias, float cell_clip,
|
||||||
const std::vector<std::vector<int>>& input_shapes)
|
float proj_clip, const std::vector<std::vector<int>>& input_shapes)
|
||||||
: UnidirectionalLSTMOpModel(
|
: UnidirectionalLSTMOpModel(
|
||||||
n_batch, n_input, n_cell, n_output, sequence_length, use_cifg,
|
n_batch, n_input, n_cell, n_output, sequence_length, time_major,
|
||||||
use_peephole, use_projection_weights, use_projection_bias,
|
use_cifg, use_peephole, use_projection_weights, use_projection_bias,
|
||||||
cell_clip, proj_clip, input_shapes, TensorType_UINT8) {}
|
cell_clip, proj_clip, input_shapes, TensorType_UINT8) {}
|
||||||
|
|
||||||
void SetInputToInputWeights(const std::vector<float>& f) {
|
void SetInputToInputWeights(const std::vector<float>& f) {
|
||||||
@ -326,21 +326,32 @@ class BaseLstmTest : public ::testing::Test {
|
|||||||
// Compares output up to tolerance to the result of the lstm given the input.
|
// Compares output up to tolerance to the result of the lstm given the input.
|
||||||
void VerifyGoldens(const std::vector<std::vector<float>>& input,
|
void VerifyGoldens(const std::vector<std::vector<float>>& input,
|
||||||
const std::vector<std::vector<float>>& output,
|
const std::vector<std::vector<float>>& output,
|
||||||
UnidirectionalLSTMOpModel* lstm, float tolerance = 1e-5) {
|
UnidirectionalLSTMOpModel* lstm, float tolerance = 1e-5,
|
||||||
|
bool time_major = true) {
|
||||||
const int num_batches = input.size();
|
const int num_batches = input.size();
|
||||||
EXPECT_GT(num_batches, 0);
|
EXPECT_GT(num_batches, 0);
|
||||||
const int num_inputs = lstm->num_inputs();
|
const int num_inputs = lstm->num_inputs();
|
||||||
EXPECT_GT(num_inputs, 0);
|
EXPECT_GT(num_inputs, 0);
|
||||||
const int input_sequence_size = input[0].size() / num_inputs;
|
const int input_sequence_size = input[0].size() / num_inputs;
|
||||||
EXPECT_GT(input_sequence_size, 0);
|
EXPECT_GT(input_sequence_size, 0);
|
||||||
// Feed the whole sequence as input.
|
if (time_major) {
|
||||||
for (int i = 0; i < input_sequence_size; ++i) {
|
// Feed the whole sequence as input.
|
||||||
for (int b = 0; b < num_batches; ++b) {
|
for (int i = 0; i < input_sequence_size; ++i) {
|
||||||
const float* batch_start = input[b].data() + i * num_inputs;
|
for (int b = 0; b < num_batches; ++b) {
|
||||||
const float* batch_end = batch_start + num_inputs;
|
const float* batch_start = input[b].data() + i * num_inputs;
|
||||||
|
const float* batch_end = batch_start + num_inputs;
|
||||||
|
|
||||||
lstm->SetInput(((i * num_batches) + b) * lstm->num_inputs(),
|
lstm->SetInput(((i * num_batches) + b) * num_inputs, batch_start,
|
||||||
batch_start, batch_end);
|
batch_end);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int b = 0; b < num_batches; ++b) {
|
||||||
|
const float* batch_start = input[b].data();
|
||||||
|
const float* batch_end = batch_start + input_sequence_size * num_inputs;
|
||||||
|
|
||||||
|
lstm->SetInput(b * input_sequence_size * num_inputs, batch_start,
|
||||||
|
batch_end);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -349,15 +360,25 @@ class BaseLstmTest : public ::testing::Test {
|
|||||||
const int num_outputs = lstm->num_outputs();
|
const int num_outputs = lstm->num_outputs();
|
||||||
EXPECT_GT(num_outputs, 0);
|
EXPECT_GT(num_outputs, 0);
|
||||||
std::vector<float> expected;
|
std::vector<float> expected;
|
||||||
for (int i = 0; i < input_sequence_size; ++i) {
|
|
||||||
for (int b = 0; b < num_batches; ++b) {
|
|
||||||
const float* golden_start_batch = output[b].data() + i * num_outputs;
|
|
||||||
const float* golden_end_batch = golden_start_batch + num_outputs;
|
|
||||||
|
|
||||||
expected.insert(expected.end(), golden_start_batch, golden_end_batch);
|
if (time_major) {
|
||||||
|
for (int i = 0; i < input_sequence_size; ++i) {
|
||||||
|
for (int b = 0; b < num_batches; ++b) {
|
||||||
|
const float* golden_start_batch = output[b].data() + i * num_outputs;
|
||||||
|
const float* golden_end_batch = golden_start_batch + num_outputs;
|
||||||
|
|
||||||
|
expected.insert(expected.end(), golden_start_batch, golden_end_batch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int b = 0; b < num_batches; ++b) {
|
||||||
|
const float* golden_batch_start = output[b].data();
|
||||||
|
const float* golden_batch_end =
|
||||||
|
golden_batch_start + input_sequence_size * num_outputs;
|
||||||
|
|
||||||
|
expected.insert(expected.end(), golden_batch_start, golden_batch_end);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPECT_THAT(lstm->GetOutput(),
|
EXPECT_THAT(lstm->GetOutput(),
|
||||||
ElementsAreArray(ArrayFloatNear(expected, tolerance)));
|
ElementsAreArray(ArrayFloatNear(expected, tolerance)));
|
||||||
}
|
}
|
||||||
@ -422,7 +443,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
|
|||||||
|
|
||||||
UnidirectionalLSTMOpModel lstm(
|
UnidirectionalLSTMOpModel lstm(
|
||||||
n_batch, n_input, n_cell, n_output, sequence_length,
|
n_batch, n_input, n_cell, n_output, sequence_length,
|
||||||
/*use_cifg=*/false, /*use_peephole=*/false,
|
/*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/false,
|
||||||
/*use_projection_weights=*/false,
|
/*use_projection_weights=*/false,
|
||||||
/*use_projection_bias=*/false,
|
/*use_projection_bias=*/false,
|
||||||
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
|
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
|
||||||
@ -473,6 +494,73 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
|
|||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
|
||||||
|
LstmBlackBoxTestBatchMajor) {
|
||||||
|
const int n_batch = 1;
|
||||||
|
const int n_input = 2;
|
||||||
|
// n_cell and n_output have the same size when there is no projection.
|
||||||
|
const int n_cell = 4;
|
||||||
|
const int n_output = 4;
|
||||||
|
const int sequence_length = 3;
|
||||||
|
|
||||||
|
UnidirectionalLSTMOpModel lstm(
|
||||||
|
n_batch, n_input, n_cell, n_output, sequence_length,
|
||||||
|
/*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/false,
|
||||||
|
/*use_projection_weights=*/false,
|
||||||
|
/*use_projection_bias=*/false,
|
||||||
|
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
|
||||||
|
{
|
||||||
|
{sequence_length, n_batch, n_input}, // input tensor
|
||||||
|
|
||||||
|
{n_cell, n_input}, // input_to_input_weight tensor
|
||||||
|
{n_cell, n_input}, // input_to_forget_weight tensor
|
||||||
|
{n_cell, n_input}, // input_to_cell_weight tensor
|
||||||
|
{n_cell, n_input}, // input_to_output_weight tensor
|
||||||
|
|
||||||
|
{n_cell, n_output}, // recurrent_to_input_weight tensor
|
||||||
|
{n_cell, n_output}, // recurrent_to_forget_weight tensor
|
||||||
|
{n_cell, n_output}, // recurrent_to_cell_weight tensor
|
||||||
|
{n_cell, n_output}, // recurrent_to_output_weight tensor
|
||||||
|
|
||||||
|
{0}, // cell_to_input_weight tensor
|
||||||
|
{0}, // cell_to_forget_weight tensor
|
||||||
|
{0}, // cell_to_output_weight tensor
|
||||||
|
|
||||||
|
{n_cell}, // input_gate_bias tensor
|
||||||
|
{n_cell}, // forget_gate_bias tensor
|
||||||
|
{n_cell}, // cell_bias tensor
|
||||||
|
{n_cell}, // output_gate_bias tensor
|
||||||
|
|
||||||
|
{0, 0}, // projection_weight tensor
|
||||||
|
{0}, // projection_bias tensor
|
||||||
|
|
||||||
|
{n_batch, n_output}, // activation_state tensor
|
||||||
|
{n_batch, n_cell}, // cell_state tensor
|
||||||
|
});
|
||||||
|
|
||||||
|
lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||||
|
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||||
|
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||||
|
lstm.SetInputToOutputWeights(input_to_output_weights_);
|
||||||
|
|
||||||
|
lstm.SetInputGateBias(input_gate_bias_);
|
||||||
|
lstm.SetCellBias(cell_gate_bias_);
|
||||||
|
lstm.SetForgetGateBias(forget_gate_bias_);
|
||||||
|
lstm.SetOutputGateBias(output_gate_bias_);
|
||||||
|
|
||||||
|
lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
||||||
|
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
||||||
|
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
||||||
|
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
||||||
|
|
||||||
|
// Reshuffle input and output to batch major format.
|
||||||
|
std::vector<std::vector<float>> input;
|
||||||
|
std::vector<std::vector<float>> output;
|
||||||
|
|
||||||
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/1e-5,
|
||||||
|
/*time_major=*/false);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
|
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
|
||||||
const int n_batch = 1;
|
const int n_batch = 1;
|
||||||
const int n_input = 2;
|
const int n_input = 2;
|
||||||
@ -483,7 +571,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
|
|||||||
|
|
||||||
HybridUnidirectionalLSTMOpModel lstm(
|
HybridUnidirectionalLSTMOpModel lstm(
|
||||||
n_batch, n_input, n_cell, n_output, sequence_length,
|
n_batch, n_input, n_cell, n_output, sequence_length,
|
||||||
/*use_cifg=*/false, /*use_peephole=*/false,
|
/*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/false,
|
||||||
/*use_projection_weights=*/false,
|
/*use_projection_weights=*/false,
|
||||||
/*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
|
/*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
|
||||||
{
|
{
|
||||||
@ -591,7 +679,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
|
|||||||
|
|
||||||
UnidirectionalLSTMOpModel lstm(
|
UnidirectionalLSTMOpModel lstm(
|
||||||
n_batch, n_input, n_cell, n_output, sequence_length,
|
n_batch, n_input, n_cell, n_output, sequence_length,
|
||||||
/*use_cifg=*/true, /*use_peephole=*/true,
|
/*time_major=*/true, /*use_cifg=*/true, /*use_peephole=*/true,
|
||||||
/*use_projection_weights=*/false,
|
/*use_projection_weights=*/false,
|
||||||
/*use_projection_bias=*/false,
|
/*use_projection_bias=*/false,
|
||||||
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
|
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
|
||||||
@ -652,7 +740,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
|
|||||||
|
|
||||||
HybridUnidirectionalLSTMOpModel lstm(
|
HybridUnidirectionalLSTMOpModel lstm(
|
||||||
n_batch, n_input, n_cell, n_output, sequence_length,
|
n_batch, n_input, n_cell, n_output, sequence_length,
|
||||||
/*use_cifg=*/true, /*use_peephole=*/true,
|
/*time_major=*/true, /*use_cifg=*/true, /*use_peephole=*/true,
|
||||||
/*use_projection_weights=*/false,
|
/*use_projection_weights=*/false,
|
||||||
/*use_projection_bias=*/false,
|
/*use_projection_bias=*/false,
|
||||||
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
|
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
|
||||||
@ -1311,7 +1399,7 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
|
|||||||
|
|
||||||
UnidirectionalLSTMOpModel lstm(
|
UnidirectionalLSTMOpModel lstm(
|
||||||
n_batch, n_input, n_cell, n_output, sequence_length,
|
n_batch, n_input, n_cell, n_output, sequence_length,
|
||||||
/*use_cifg=*/false, /*use_peephole=*/true,
|
/*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true,
|
||||||
/*use_projection_weights=*/true,
|
/*use_projection_weights=*/true,
|
||||||
/*use_projection_bias=*/false,
|
/*use_projection_bias=*/false,
|
||||||
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
|
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
|
||||||
@ -1377,7 +1465,7 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
|
|||||||
|
|
||||||
HybridUnidirectionalLSTMOpModel lstm(
|
HybridUnidirectionalLSTMOpModel lstm(
|
||||||
n_batch, n_input, n_cell, n_output, sequence_length,
|
n_batch, n_input, n_cell, n_output, sequence_length,
|
||||||
/*use_cifg=*/false, /*use_peephole=*/true,
|
/*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true,
|
||||||
/*use_projection_weights=*/true,
|
/*use_projection_weights=*/true,
|
||||||
/*use_projection_bias=*/false,
|
/*use_projection_bias=*/false,
|
||||||
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
|
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
|
||||||
|
@ -407,6 +407,9 @@ table UnidirectionalSequenceLSTMOptions {
|
|||||||
fused_activation_function:ActivationFunctionType;
|
fused_activation_function:ActivationFunctionType;
|
||||||
cell_clip: float; // Optional, 0.0 means no clipping
|
cell_clip: float; // Optional, 0.0 means no clipping
|
||||||
proj_clip: float; // Optional, 0.0 means no clipping
|
proj_clip: float; // Optional, 0.0 means no clipping
|
||||||
|
|
||||||
|
// If true then first dimension is sequence, otherwise batch.
|
||||||
|
time_major:bool;
|
||||||
}
|
}
|
||||||
|
|
||||||
table BidirectionalSequenceLSTMOptions {
|
table BidirectionalSequenceLSTMOptions {
|
||||||
|
@ -3534,10 +3534,12 @@ struct UnidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
|
|||||||
ActivationFunctionType fused_activation_function;
|
ActivationFunctionType fused_activation_function;
|
||||||
float cell_clip;
|
float cell_clip;
|
||||||
float proj_clip;
|
float proj_clip;
|
||||||
|
bool time_major;
|
||||||
UnidirectionalSequenceLSTMOptionsT()
|
UnidirectionalSequenceLSTMOptionsT()
|
||||||
: fused_activation_function(ActivationFunctionType_NONE),
|
: fused_activation_function(ActivationFunctionType_NONE),
|
||||||
cell_clip(0.0f),
|
cell_clip(0.0f),
|
||||||
proj_clip(0.0f) {
|
proj_clip(0.0f),
|
||||||
|
time_major(false) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -3546,7 +3548,8 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb
|
|||||||
enum {
|
enum {
|
||||||
VT_FUSED_ACTIVATION_FUNCTION = 4,
|
VT_FUSED_ACTIVATION_FUNCTION = 4,
|
||||||
VT_CELL_CLIP = 6,
|
VT_CELL_CLIP = 6,
|
||||||
VT_PROJ_CLIP = 8
|
VT_PROJ_CLIP = 8,
|
||||||
|
VT_TIME_MAJOR = 10
|
||||||
};
|
};
|
||||||
ActivationFunctionType fused_activation_function() const {
|
ActivationFunctionType fused_activation_function() const {
|
||||||
return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
|
||||||
@ -3557,11 +3560,15 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb
|
|||||||
float proj_clip() const {
|
float proj_clip() const {
|
||||||
return GetField<float>(VT_PROJ_CLIP, 0.0f);
|
return GetField<float>(VT_PROJ_CLIP, 0.0f);
|
||||||
}
|
}
|
||||||
|
bool time_major() const {
|
||||||
|
return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
|
||||||
|
}
|
||||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||||
return VerifyTableStart(verifier) &&
|
return VerifyTableStart(verifier) &&
|
||||||
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
|
||||||
VerifyField<float>(verifier, VT_CELL_CLIP) &&
|
VerifyField<float>(verifier, VT_CELL_CLIP) &&
|
||||||
VerifyField<float>(verifier, VT_PROJ_CLIP) &&
|
VerifyField<float>(verifier, VT_PROJ_CLIP) &&
|
||||||
|
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
|
||||||
verifier.EndTable();
|
verifier.EndTable();
|
||||||
}
|
}
|
||||||
UnidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
UnidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
@ -3581,6 +3588,9 @@ struct UnidirectionalSequenceLSTMOptionsBuilder {
|
|||||||
void add_proj_clip(float proj_clip) {
|
void add_proj_clip(float proj_clip) {
|
||||||
fbb_.AddElement<float>(UnidirectionalSequenceLSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f);
|
fbb_.AddElement<float>(UnidirectionalSequenceLSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f);
|
||||||
}
|
}
|
||||||
|
void add_time_major(bool time_major) {
|
||||||
|
fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast<uint8_t>(time_major), 0);
|
||||||
|
}
|
||||||
explicit UnidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
explicit UnidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||||
: fbb_(_fbb) {
|
: fbb_(_fbb) {
|
||||||
start_ = fbb_.StartTable();
|
start_ = fbb_.StartTable();
|
||||||
@ -3597,10 +3607,12 @@ inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirection
|
|||||||
flatbuffers::FlatBufferBuilder &_fbb,
|
flatbuffers::FlatBufferBuilder &_fbb,
|
||||||
ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
|
ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
|
||||||
float cell_clip = 0.0f,
|
float cell_clip = 0.0f,
|
||||||
float proj_clip = 0.0f) {
|
float proj_clip = 0.0f,
|
||||||
|
bool time_major = false) {
|
||||||
UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
|
UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
|
||||||
builder_.add_proj_clip(proj_clip);
|
builder_.add_proj_clip(proj_clip);
|
||||||
builder_.add_cell_clip(cell_clip);
|
builder_.add_cell_clip(cell_clip);
|
||||||
|
builder_.add_time_major(time_major);
|
||||||
builder_.add_fused_activation_function(fused_activation_function);
|
builder_.add_fused_activation_function(fused_activation_function);
|
||||||
return builder_.Finish();
|
return builder_.Finish();
|
||||||
}
|
}
|
||||||
@ -8060,6 +8072,7 @@ inline void UnidirectionalSequenceLSTMOptions::UnPackTo(UnidirectionalSequenceLS
|
|||||||
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; };
|
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; };
|
||||||
{ auto _e = cell_clip(); _o->cell_clip = _e; };
|
{ auto _e = cell_clip(); _o->cell_clip = _e; };
|
||||||
{ auto _e = proj_clip(); _o->proj_clip = _e; };
|
{ auto _e = proj_clip(); _o->proj_clip = _e; };
|
||||||
|
{ auto _e = time_major(); _o->time_major = _e; };
|
||||||
}
|
}
|
||||||
|
|
||||||
inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> UnidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> UnidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||||
@ -8073,11 +8086,13 @@ inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirection
|
|||||||
auto _fused_activation_function = _o->fused_activation_function;
|
auto _fused_activation_function = _o->fused_activation_function;
|
||||||
auto _cell_clip = _o->cell_clip;
|
auto _cell_clip = _o->cell_clip;
|
||||||
auto _proj_clip = _o->proj_clip;
|
auto _proj_clip = _o->proj_clip;
|
||||||
|
auto _time_major = _o->time_major;
|
||||||
return tflite::CreateUnidirectionalSequenceLSTMOptions(
|
return tflite::CreateUnidirectionalSequenceLSTMOptions(
|
||||||
_fbb,
|
_fbb,
|
||||||
_fused_activation_function,
|
_fused_activation_function,
|
||||||
_cell_clip,
|
_cell_clip,
|
||||||
_proj_clip);
|
_proj_clip,
|
||||||
|
_time_major);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||||
|
Loading…
Reference in New Issue
Block a user