diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h
index 1e65c3cee27..5a5f3ad61c1 100644
--- a/tensorflow/contrib/lite/c/builtin_op_data.h
+++ b/tensorflow/contrib/lite/c/builtin_op_data.h
@@ -187,10 +187,13 @@ typedef struct {
 } TfLiteLSTMParams;
 
 typedef struct {
-  // Parameters for the LSTM kernel.
+  // Parameters needed for the underlying LSTM.
   TfLiteFusedActivation activation;
   float cell_clip;
   float proj_clip;
+
+  // If set to true then the first dimension is time, otherwise batch.
+  bool time_major;
 } TfLiteUnidirectionalSequenceLSTMParams;
 
 typedef struct {
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
index 348ce54dd73..fe56c4ebf92 100644
--- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
@@ -399,11 +399,11 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
             parse_activation(seq_lstm_params->fused_activation_function());
         params->cell_clip = seq_lstm_params->cell_clip();
         params->proj_clip = seq_lstm_params->proj_clip();
+        params->time_major = seq_lstm_params->time_major();
       }
       *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
-
     case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: {
       auto params =
           allocator->AllocatePOD<TfLiteBidirectionalSequenceLSTMParams>();
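A note on semantics before the kernel changes below: `time_major` only selects how the same 3-D input buffer is laid out; the LSTM math is unchanged. A minimal self-contained sketch of the two indexing rules (helper names are illustrative, not TfLite API):

```c++
#include <cstddef>

// Flat offset of element input[batch=b][time=t][feature=f] under each layout.
// Illustrative helpers only; the TfLite kernels do this arithmetic inline.
inline size_t TimeMajorIndex(size_t t, size_t b, size_t f,
                             size_t n_batch, size_t n_input) {
  // Tensor shape: {max_time, n_batch, n_input}.
  return (t * n_batch + b) * n_input + f;
}

inline size_t BatchMajorIndex(size_t t, size_t b, size_t f,
                              size_t max_time, size_t n_input) {
  // Tensor shape: {n_batch, max_time, n_input}.
  return (b * max_time + t) * n_input + f;
}
```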
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index 60abfbc85ee..f8660fbaa23 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -876,6 +876,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1]
                             : 0;
   const auto actual_bw_output = params->merge_outputs ? fw_output : bw_output;
+  // TODO(mirkov): add batch_major support (http://b/117326122).
   switch (fw_input_to_output_weights->type) {
     case kTfLiteFloat32: {
       TfLiteStatus fw_pass_status = lstm_eval::EvalFloat(
@@ -889,8 +890,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           fw_aux_input_to_output_weights, fw_input_gate_bias,
           fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
           fw_projection_weights, fw_projection_bias, &lstm_params,
-          /*forward_sequence=*/true, /*output_offset=*/0, fw_scratch_buffer,
-          fw_activation_state, fw_cell_state, fw_output);
+          /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
+          fw_scratch_buffer, fw_activation_state, fw_cell_state, fw_output);
       TF_LITE_ENSURE_OK(context, fw_pass_status);
 
       TfLiteStatus bw_pass_status = lstm_eval::EvalFloat(
@@ -904,8 +905,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           bw_aux_input_to_output_weights, bw_input_gate_bias,
           bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
           bw_projection_weights, bw_projection_bias, &lstm_params,
-          /*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer,
-          bw_activation_state, bw_cell_state, actual_bw_output);
+          /*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,
+          bw_scratch_buffer, bw_activation_state, bw_cell_state,
+          actual_bw_output);
       TF_LITE_ENSURE_OK(context, bw_pass_status);
       return kTfLiteOk;
     }
@@ -942,11 +944,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           fw_aux_input_to_output_weights, fw_input_gate_bias,
           fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
           fw_projection_weights, fw_projection_bias, &lstm_params,
-          /*forward_sequence=*/true, /*output_offset=*/0, fw_scratch_buffer,
-          scaling_factors, prod_scaling_factors, recovered_cell_weights,
-          input_quantized, aux_input_quantized, fw_activation_state_quantized,
-          fw_cell_state_quantized, fw_activation_state, fw_cell_state,
-          fw_output);
+          /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
+          fw_scratch_buffer, scaling_factors, prod_scaling_factors,
+          recovered_cell_weights, input_quantized, aux_input_quantized,
+          fw_activation_state_quantized, fw_cell_state_quantized,
+          fw_activation_state, fw_cell_state, fw_output);
       TF_LITE_ENSURE_OK(context, fw_pass_status);
 
       TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
@@ -960,11 +962,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           bw_aux_input_to_output_weights, bw_input_gate_bias,
           bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
           bw_projection_weights, bw_projection_bias, &lstm_params,
-          /*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer,
-          scaling_factors, prod_scaling_factors, recovered_cell_weights,
-          input_quantized, aux_input_quantized, bw_activation_state_quantized,
-          bw_cell_state_quantized, bw_activation_state, bw_cell_state,
-          actual_bw_output);
+          /*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,
+          bw_scratch_buffer, scaling_factors, prod_scaling_factors,
+          recovered_cell_weights, input_quantized, aux_input_quantized,
+          bw_activation_state_quantized, bw_cell_state_quantized,
+          bw_activation_state, bw_cell_state, actual_bw_output);
       TF_LITE_ENSURE_OK(context, bw_pass_status);
       return kTfLiteOk;
     }
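Until the TODO above is resolved, the bidirectional kernel accepts only time-major data (both call sites pass /*time_major=*/true), so a caller holding batch-major buffers would have to repack first. A rough sketch of that repacking, assuming a contiguous float buffer (hypothetical helper, not part of this patch):

```c++
#include <vector>

// Hypothetical pre-pass: repack a {n_batch, max_time, n_input} buffer into
// {max_time, n_batch, n_input} so the time-major-only kernels can consume it.
std::vector<float> BatchMajorToTimeMajor(const std::vector<float>& in,
                                         int n_batch, int max_time,
                                         int n_input) {
  std::vector<float> out(in.size());
  for (int b = 0; b < n_batch; ++b) {
    for (int t = 0; t < max_time; ++t) {
      for (int f = 0; f < n_input; ++f) {
        out[(t * n_batch + b) * n_input + f] =
            in[(b * max_time + t) * n_input + f];
      }
    }
  }
  return out;
}
```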
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index f08a1a80c05..3666122e941 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -497,6 +497,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
           forget_gate_bias, cell_bias, output_gate_bias,
           projection_weights, projection_bias, params, /*forward_sequence=*/true,
+          /*time_major=*/true,
           /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
           output);
     }
@@ -524,8 +525,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
           forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
           projection_bias, params, /*forward_sequence=*/true,
-          /*output_offset=*/0, scratch_buffer, scaling_factors,
-          prod_scaling_factors, recovered_cell_weights, input_quantized,
+          /*time_major=*/true, /*output_offset=*/0, scratch_buffer,
+          scaling_factors, prod_scaling_factors, recovered_cell_weights,
+          input_quantized,
           /*aux_input_quantized=*/nullptr, activation_state_quantized,
           cell_state_quantized, activation_state, cell_state, output);
     }
diff --git a/tensorflow/contrib/lite/kernels/lstm_eval.cc b/tensorflow/contrib/lite/kernels/lstm_eval.cc
index 2ef70aa933b..5b7951a9310 100644
--- a/tensorflow/contrib/lite/kernels/lstm_eval.cc
+++ b/tensorflow/contrib/lite/kernels/lstm_eval.cc
@@ -710,9 +710,10 @@ TfLiteStatus EvalFloat(
     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
     const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
-    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
-    TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
-    TfLiteTensor* cell_state, TfLiteTensor* output) {
+    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
+    int output_offset, TfLiteTensor* scratch_buffer,
+    TfLiteTensor* activation_state, TfLiteTensor* cell_state,
+    TfLiteTensor* output) {
   TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
-  const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
-  const int n_batch = input->dims->data[input->dims->size - 2];
+  // Dim 0 is time in time-major layout and batch in batch-major layout.
+  int max_time, n_batch;
+  if (input->dims->size == 3) {
+    max_time = time_major ? input->dims->data[0] : input->dims->data[1];
+    n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
+  } else {
+    max_time = 1;
+    n_batch = input->dims->data[0];
+  }
@@ -777,36 +778,71 @@ TfLiteStatus EvalFloat(
     aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f;
   }
 
-  // Loop through the sequence.
   const int output_batch_leading_dim =
       output->dims->data[output->dims->size - 1];
-  const int input_step = n_batch * n_input;
-  const int output_step = n_batch * output_batch_leading_dim;
-  for (int t = 0; t < max_time; t++) {
-    // If this is the forward_sequence, step forward, otherwise step backwards.
-    const int t_rel = forward_sequence ? t : max_time - t - 1;
-    const float* input_ptr = input->data.f + t_rel * input_step;
-    if (aux_input) {
-      aux_input_ptr = aux_input->data.f + t_rel * input_step;
-    }
-    float* output_ptr_time =
-        output->data.f + t_rel * output_step + output_offset;
+  if (time_major) {
+    // Loop through the sequence.
+    const int input_step = n_batch * n_input;
+    const int output_step = n_batch * output_batch_leading_dim;
+    for (int t = 0; t < max_time; t++) {
+      // If this is the forward_sequence, step forward, otherwise step
+      // backwards.
+      const int t_rel = forward_sequence ? t : max_time - t - 1;
+      const float* input_ptr = input->data.f + t_rel * input_step;
+      if (aux_input) {
+        aux_input_ptr = aux_input->data.f + t_rel * input_step;
+      }
+      float* output_ptr_time =
+          output->data.f + t_rel * output_step + output_offset;
 
-    LstmStepWithAuxInput(
-        input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f,
-        input_to_cell_weights->data.f, input_to_output_weights->data.f,
-        aux_input_ptr, aux_input_to_input_weights_ptr,
-        aux_input_to_forget_weights_ptr, aux_input_to_cell_weights_ptr,
-        aux_input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
-        recurrent_to_forget_weights->data.f, recurrent_to_cell_weights->data.f,
-        recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
-        cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
-        input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
-        output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
-        params, n_batch, n_cell, n_input, aux_input_size, n_output,
-        output_batch_leading_dim, activation_state->data.f, cell_state->data.f,
-        input_gate_scratch, forget_gate_scratch, cell_scratch,
-        output_gate_scratch, output_ptr_time);
+      LstmStepWithAuxInput(
+          input_ptr, input_to_input_weights_ptr,
+          input_to_forget_weights->data.f, input_to_cell_weights->data.f,
+          input_to_output_weights->data.f, aux_input_ptr,
+          aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
+          aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
+          recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
+          recurrent_to_cell_weights->data.f,
+          recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+          cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+          input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+          output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
+          params, n_batch, n_cell, n_input, aux_input_size, n_output,
+          output_batch_leading_dim, activation_state->data.f,
+          cell_state->data.f, input_gate_scratch, forget_gate_scratch,
+          cell_scratch, output_gate_scratch, output_ptr_time);
+    }
+  } else {
+    for (int b = 0; b < n_batch; b++) {
+      const int input_step = n_input;
+      const int output_step = output_batch_leading_dim;
+      for (int t = 0; t < max_time; t++) {
+        // If this is the forward_sequence, step forward, otherwise step
+        // backwards.
+        const int t_rel = forward_sequence ? t : max_time - t - 1;
+        // Batch b's sequence is contiguous in batch-major layout, so skip
+        // b * max_time rows before indexing the time step.
+        const int time_offset = b * max_time + t_rel;
+        const float* input_ptr = input->data.f + time_offset * input_step;
+        if (aux_input) {
+          aux_input_ptr = aux_input->data.f + time_offset * input_step;
+        }
+        float* output_ptr_time =
+            output->data.f + time_offset * output_step + output_offset;
+
+        // Offset the {activation,cell}_state pointers to the current batch.
+        float* activation_state_ptr = activation_state->data.f + b * n_output;
+        float* cell_state_ptr = cell_state->data.f + b * n_cell;
+
+        LstmStepWithAuxInput(
+            input_ptr, input_to_input_weights_ptr,
+            input_to_forget_weights->data.f, input_to_cell_weights->data.f,
+            input_to_output_weights->data.f, aux_input_ptr,
+            aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
+            aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
+            recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
+            recurrent_to_cell_weights->data.f,
+            recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+            cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+            input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+            output_gate_bias->data.f, projection_weights_ptr,
+            projection_bias_ptr, params, /*n_batch=*/1, n_cell, n_input,
+            aux_input_size, n_output, output_batch_leading_dim,
+            activation_state_ptr, cell_state_ptr, input_gate_scratch,
+            forget_gate_scratch, cell_scratch, output_gate_scratch,
+            output_ptr_time);
+      }
+    }
   }
   return kTfLiteOk;
 }
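The batch-major branch above offsets by `b * max_time` rows because each batch's whole sequence is stored contiguously, whereas the time-major branch strides by whole `n_batch`-wide rows per step. A tiny standalone check of the arithmetic (illustrative only, not TfLite code):

```c++
#include <cassert>

int main() {
  const int n_batch = 2, max_time = 3, n_input = 5;
  const int b = 1, t_rel = 0;
  // Batch-major rows are ordered b0t0, b0t1, b0t2, b1t0, b1t1, b1t2,
  // so batch b's sequence starts b * max_time rows into the buffer.
  assert((b * max_time + t_rel) * n_input == 15);
  // Time-major rows are ordered t0b0, t0b1, t1b0, t1b1, t2b0, t2b1.
  assert((t_rel * n_batch + b) * n_input == 5);
  return 0;
}
```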
@@ -830,13 +866,13 @@ TfLiteStatus EvalHybrid(
     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
     const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
-    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
-    TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
-    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
-    TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
-    TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
-    TfLiteTensor* output_state, TfLiteTensor* cell_state,
-    TfLiteTensor* output) {
+    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
+    int output_offset, TfLiteTensor* scratch_buffer,
+    TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
+    TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
+    TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
+    TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
+    TfLiteTensor* cell_state, TfLiteTensor* output) {
   TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
-  const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
-  const int n_batch = input->dims->data[input->dims->size - 2];
+  // As in EvalFloat: dim 0 is time when time_major, batch otherwise.
+  int max_time, n_batch;
+  if (input->dims->size == 3) {
+    max_time = time_major ? input->dims->data[0] : input->dims->data[1];
+    n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
+  } else {
+    max_time = 1;
+    n_batch = input->dims->data[0];
+  }
@@ -990,45 +1026,90 @@ TfLiteStatus EvalHybrid(
         aux_input_to_output_weights->params.scale;
   }
 
-  // Feed the sequence into the LSTM step-by-step.
   const int output_batch_leading_dim =
       output->dims->data[output->dims->size - 1];
-  const int input_step = n_batch * n_input;
-  const int output_step = n_batch * output_batch_leading_dim;
-  for (int t = 0; t < max_time; t++) {
-    // If this is the forward_sequence, step forward, otherwise step backwards.
-    const int t_rel = forward_sequence ? t : max_time - t - 1;
-    const float* input_ptr = input->data.f + t_rel * input_step;
-    if (aux_input) {
-      aux_input_ptr = aux_input->data.f + t_rel * input_step;
-    }
-    float* output_ptr = output->data.f + t_rel * output_step + output_offset;
+  if (time_major) {
+    // Feed the sequence into the LSTM step-by-step.
+    const int input_step = n_batch * n_input;
+    const int output_step = n_batch * output_batch_leading_dim;
+    for (int t = 0; t < max_time; t++) {
+      // If this is the forward_sequence, step forward, otherwise step
+      // backwards.
+      const int t_rel = forward_sequence ? t : max_time - t - 1;
+      const float* input_ptr = input->data.f + t_rel * input_step;
+      if (aux_input) {
+        aux_input_ptr = aux_input->data.f + t_rel * input_step;
+      }
+      float* output_ptr = output->data.f + t_rel * output_step + output_offset;
 
-    LstmStepWithAuxInput(
-        input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
-        input_to_forget_weights_ptr, input_to_forget_weights_scale,
-        input_to_cell_weights_ptr, input_to_cell_weights_scale,
-        input_to_output_weights_ptr, input_to_output_weights_scale,
-        aux_input_ptr, aux_input_to_input_weights_ptr,
-        aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
-        aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
-        aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
-        aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
-        recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
-        recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
-        recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
-        recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
-        cell_to_input_weights_scale, cell_to_forget_weights_ptr,
-        cell_to_forget_weights_scale, cell_to_output_weights_ptr,
-        cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
-        cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
-        projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
-        n_input, aux_input_size, n_output, output_batch_leading_dim,
-        input_gate_scratch, forget_gate_scratch, cell_scratch,
-        output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
-        recovered_cell_weights_ptr, quantized_input_ptr,
-        quantized_aux_input_ptr, quantized_output_state_ptr,
-        quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr);
+      LstmStepWithAuxInput(
+          input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+          input_to_forget_weights_ptr, input_to_forget_weights_scale,
+          input_to_cell_weights_ptr, input_to_cell_weights_scale,
+          input_to_output_weights_ptr, input_to_output_weights_scale,
+          aux_input_ptr, aux_input_to_input_weights_ptr,
+          aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
+          aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
+          aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
+          aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
+          recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
+          recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
+          recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
+          recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
+          cell_to_input_weights_scale, cell_to_forget_weights_ptr,
+          cell_to_forget_weights_scale, cell_to_output_weights_ptr,
+          cell_to_output_weights_scale, input_gate_bias_ptr,
+          forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
+          projection_weights_ptr, projection_weights_scale, projection_bias_ptr,
+          params, n_batch, n_cell, n_input, aux_input_size, n_output,
+          output_batch_leading_dim, input_gate_scratch, forget_gate_scratch,
+          cell_scratch, output_gate_scratch, scaling_factors_ptr,
+          prod_scaling_factors_ptr, recovered_cell_weights_ptr,
+          quantized_input_ptr, quantized_aux_input_ptr,
+          quantized_output_state_ptr, quantized_cell_state_ptr,
+          output_state_ptr, cell_state_ptr, output_ptr);
+    }
+  } else {
+    for (int b = 0; b < n_batch; b++) {
+      const int input_step = n_input;
+      const int output_step = output_batch_leading_dim;
+      for (int t = 0; t < max_time; t++) {
+        // If this is the forward_sequence, step forward, otherwise step
+        // backwards.
+        const int t_rel = forward_sequence ? t : max_time - t - 1;
+        // Batch b's sequence is contiguous in batch-major layout, so skip
+        // b * max_time rows before indexing the time step.
+        const int time_offset = b * max_time + t_rel;
+        const float* input_ptr = input->data.f + time_offset * input_step;
+        if (aux_input) {
+          aux_input_ptr = aux_input->data.f + time_offset * input_step;
+        }
+        float* output_ptr =
+            output->data.f + time_offset * output_step + output_offset;
+
+        // Offset the {output,cell}_state pointers to the current batch.
+        float* output_state_batch_ptr = output_state->data.f + b * n_output;
+        float* cell_state_batch_ptr = cell_state->data.f + b * n_cell;
+
+        LstmStepWithAuxInput(
+            input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+            input_to_forget_weights_ptr, input_to_forget_weights_scale,
+            input_to_cell_weights_ptr, input_to_cell_weights_scale,
+            input_to_output_weights_ptr, input_to_output_weights_scale,
+            aux_input_ptr, aux_input_to_input_weights_ptr,
+            aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
+            aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
+            aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
+            aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
+            recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
+            recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
+            recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
+            recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
+            cell_to_input_weights_scale, cell_to_forget_weights_ptr,
+            cell_to_forget_weights_scale, cell_to_output_weights_ptr,
+            cell_to_output_weights_scale, input_gate_bias_ptr,
+            forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
+            projection_weights_ptr, projection_weights_scale,
+            projection_bias_ptr, params, /*n_batch=*/1, n_cell, n_input,
+            aux_input_size, n_output, output_batch_leading_dim,
+            input_gate_scratch, forget_gate_scratch, cell_scratch,
+            output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
+            recovered_cell_weights_ptr, quantized_input_ptr,
+            quantized_aux_input_ptr, quantized_output_state_ptr,
+            quantized_cell_state_ptr, output_state_batch_ptr,
+            cell_state_batch_ptr, output_ptr);
+      }
+    }
   }
   return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/kernels/lstm_eval.h b/tensorflow/contrib/lite/kernels/lstm_eval.h
index adf8cf0f645..8d8b97aead6 100644
--- a/tensorflow/contrib/lite/kernels/lstm_eval.h
+++ b/tensorflow/contrib/lite/kernels/lstm_eval.h
@@ -42,9 +42,10 @@ TfLiteStatus EvalFloat(
     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
     const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
-    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
-    TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
-    TfLiteTensor* cell_state, TfLiteTensor* output);
+    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
+    int output_offset, TfLiteTensor* scratch_buffer,
+    TfLiteTensor* activation_state, TfLiteTensor* cell_state,
+    TfLiteTensor* output);
 
 TfLiteStatus EvalHybrid(
     const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
@@ -65,12 +66,13 @@ TfLiteStatus EvalHybrid(
     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
     const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
-    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
-    TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
-    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
-    TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
-    TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
-    TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output);
+    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
+    int output_offset, TfLiteTensor* scratch_buffer,
+    TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
+    TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
+    TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
+    TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
+    TfLiteTensor* cell_state, TfLiteTensor* output);
 
 }  // namespace lstm_eval
 }  // namespace builtin
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index 40029779e0e..bd6d4d1f884 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -260,8 +260,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
   TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
   TF_LITE_ENSURE(context, input->dims->size > 1);
-  const int max_time = input->dims->data[0];
-  const int n_batch = input->dims->data[1];
+  const auto* params =
+      reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
+          node->builtin_data);
+  const bool time_major = params->time_major;
+  const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
   const int n_input = input->dims->data[2];
 
   const TfLiteTensor* input_to_output_weights =
@@ -296,10 +299,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
 
   // Resize the output tensors.
-  TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
-  output_size->data[0] = max_time;
-  output_size->data[1] = n_batch;
-  output_size->data[2] = n_output;
+  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
+  output_size->data[input->dims->size - 1] = n_output;
   TF_LITE_ENSURE_OK(context,
                     context->ResizeTensor(context, output, output_size));
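The resize above now works for either layout because it copies the input's dims and only swaps the innermost (feature) dimension for n_output. The same rule as a standalone sketch (illustrative; the kernel itself uses TfLiteIntArrayCopy):

```c++
#include <vector>

// Keep the batch/time ordering of the input shape as-is; only the innermost
// dimension changes, from n_input to n_output.
std::vector<int> OutputShape(const std::vector<int>& input_dims, int n_output) {
  std::vector<int> output_dims = input_dims;
  output_dims.back() = n_output;
  return output_dims;
}
```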
@@ -436,6 +437,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const auto* params =
       reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
           node->builtin_data);
+  const bool time_major = params->time_major;
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
 
   const TfLiteTensor* input_to_input_weights =
@@ -506,7 +508,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           /*aux_input_to_cell_weights=*/nullptr,
           /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
           forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
-          projection_bias, &lstm_params, /*forward_sequence=*/true,
+          projection_bias, &lstm_params, /*forward_sequence=*/true, time_major,
           /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
           output);
     }
@@ -533,7 +535,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           /*aux_input_to_cell_weights=*/nullptr,
           /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
           forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
-          projection_bias, &lstm_params, /*forward_sequence=*/true,
+          projection_bias, &lstm_params, /*forward_sequence=*/true, time_major,
           /*output_offset=*/0, scratch_buffer, scaling_factors,
           prod_scaling_factors, recovered_cell_weights, input_quantized,
           /*aux_input_quantized=*/nullptr, activation_state_quantized,
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
index 7b9d66c19b4..1de14dd60db 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
@@ -32,7 +32,7 @@ using ::testing::ElementsAreArray;
 class UnidirectionalLSTMOpModel : public SingleOpModel {
  public:
   UnidirectionalLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
-                            int sequence_length, bool use_cifg,
+                            int sequence_length, bool time_major, bool use_cifg,
                             bool use_peephole, bool use_projection_weights,
                             bool use_projection_bias, float cell_clip,
                             float proj_clip,
@@ -110,12 +110,12 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
 
     output_ = AddOutput(TensorType_FLOAT32);
 
-    SetBuiltinOp(
-        BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
-        BuiltinOptions_UnidirectionalSequenceLSTMOptions,
-        CreateUnidirectionalSequenceLSTMOptions(
-            builder_, ActivationFunctionType_TANH, cell_clip, proj_clip)
-            .Union());
+    SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
+                 BuiltinOptions_UnidirectionalSequenceLSTMOptions,
+                 CreateUnidirectionalSequenceLSTMOptions(
+                     builder_, ActivationFunctionType_TANH, cell_clip,
+                     proj_clip, time_major)
+                     .Union());
 
     BuildInterpreter(input_shapes);
   }
@@ -241,12 +241,12 @@ class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
  public:
   HybridUnidirectionalLSTMOpModel(
       int n_batch, int n_input, int n_cell, int n_output, int sequence_length,
-      bool use_cifg, bool use_peephole, bool use_projection_weights,
-      bool use_projection_bias, float cell_clip, float proj_clip,
-      const std::vector<std::vector<int>>& input_shapes)
+      bool time_major, bool use_cifg, bool use_peephole,
+      bool use_projection_weights, bool use_projection_bias, float cell_clip,
+      float proj_clip, const std::vector<std::vector<int>>& input_shapes)
       : UnidirectionalLSTMOpModel(
-            n_batch, n_input, n_cell, n_output, sequence_length, use_cifg,
-            use_peephole, use_projection_weights, use_projection_bias,
+            n_batch, n_input, n_cell, n_output, sequence_length, time_major,
+            use_cifg, use_peephole, use_projection_weights, use_projection_bias,
             cell_clip, proj_clip, input_shapes, TensorType_UINT8) {}
 
   void SetInputToInputWeights(const std::vector<float>& f) {
@@ -326,21 +326,32 @@ class BaseLstmTest : public ::testing::Test {
   // Compares output up to tolerance to the result of the lstm given the input.
   void VerifyGoldens(const std::vector<std::vector<float>>& input,
                      const std::vector<std::vector<float>>& output,
-                     UnidirectionalLSTMOpModel* lstm, float tolerance = 1e-5) {
+                     UnidirectionalLSTMOpModel* lstm, float tolerance = 1e-5,
+                     bool time_major = true) {
     const int num_batches = input.size();
     EXPECT_GT(num_batches, 0);
     const int num_inputs = lstm->num_inputs();
     EXPECT_GT(num_inputs, 0);
     const int input_sequence_size = input[0].size() / num_inputs;
     EXPECT_GT(input_sequence_size, 0);
-    // Feed the whole sequence as input.
-    for (int i = 0; i < input_sequence_size; ++i) {
-      for (int b = 0; b < num_batches; ++b) {
-        const float* batch_start = input[b].data() + i * num_inputs;
-        const float* batch_end = batch_start + num_inputs;
+    if (time_major) {
+      // Feed the whole sequence as input.
+      for (int i = 0; i < input_sequence_size; ++i) {
+        for (int b = 0; b < num_batches; ++b) {
+          const float* batch_start = input[b].data() + i * num_inputs;
+          const float* batch_end = batch_start + num_inputs;
 
-        lstm->SetInput(((i * num_batches) + b) * lstm->num_inputs(),
-                       batch_start, batch_end);
+          lstm->SetInput(((i * num_batches) + b) * num_inputs, batch_start,
+                         batch_end);
+        }
+      }
+    } else {
+      for (int b = 0; b < num_batches; ++b) {
+        const float* batch_start = input[b].data();
+        const float* batch_end = batch_start + input_sequence_size * num_inputs;
+
+        lstm->SetInput(b * input_sequence_size * num_inputs, batch_start,
+                       batch_end);
       }
     }
 
@@ -349,15 +360,25 @@ class BaseLstmTest : public ::testing::Test {
     const int num_outputs = lstm->num_outputs();
     EXPECT_GT(num_outputs, 0);
     std::vector<float> expected;
-    for (int i = 0; i < input_sequence_size; ++i) {
-      for (int b = 0; b < num_batches; ++b) {
-        const float* golden_start_batch = output[b].data() + i * num_outputs;
-        const float* golden_end_batch = golden_start_batch + num_outputs;
-        expected.insert(expected.end(), golden_start_batch, golden_end_batch);
+    if (time_major) {
+      for (int i = 0; i < input_sequence_size; ++i) {
+        for (int b = 0; b < num_batches; ++b) {
+          const float* golden_start_batch = output[b].data() + i * num_outputs;
+          const float* golden_end_batch = golden_start_batch + num_outputs;
+
+          expected.insert(expected.end(), golden_start_batch,
+                          golden_end_batch);
+        }
+      }
+    } else {
+      for (int b = 0; b < num_batches; ++b) {
+        const float* golden_batch_start = output[b].data();
+        const float* golden_batch_end =
+            golden_batch_start + input_sequence_size * num_outputs;
+
+        expected.insert(expected.end(), golden_batch_start, golden_batch_end);
       }
     }
-
     EXPECT_THAT(lstm->GetOutput(),
                 ElementsAreArray(ArrayFloatNear(expected, tolerance)));
   }
@@ -422,7 +443,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
 
   UnidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length,
-      /*use_cifg=*/false, /*use_peephole=*/false,
+      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/false,
       /*use_projection_weights=*/false,
      /*use_projection_bias=*/false,
       /*cell_clip=*/0.0, /*proj_clip=*/0.0,
@@ -473,6 +494,73 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
 }
 
+TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
+       LstmBlackBoxTestBatchMajor) {
+  const int n_batch = 1;
+  const int n_input = 2;
+  // n_cell and n_output have the same size when there is no projection.
+  const int n_cell = 4;
+  const int n_output = 4;
+  const int sequence_length = 3;
+
+  UnidirectionalLSTMOpModel lstm(
+      n_batch, n_input, n_cell, n_output, sequence_length,
+      /*time_major=*/false, /*use_cifg=*/false, /*use_peephole=*/false,
+      /*use_projection_weights=*/false,
+      /*use_projection_bias=*/false,
+      /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+      {
+          {n_batch, sequence_length, n_input},  // input tensor (batch-major)
+
+          {n_cell, n_input},  // input_to_input_weight tensor
+          {n_cell, n_input},  // input_to_forget_weight tensor
+          {n_cell, n_input},  // input_to_cell_weight tensor
+          {n_cell, n_input},  // input_to_output_weight tensor
+
+          {n_cell, n_output},  // recurrent_to_input_weight tensor
+          {n_cell, n_output},  // recurrent_to_forget_weight tensor
+          {n_cell, n_output},  // recurrent_to_cell_weight tensor
+          {n_cell, n_output},  // recurrent_to_output_weight tensor
+
+          {0},  // cell_to_input_weight tensor
+          {0},  // cell_to_forget_weight tensor
+          {0},  // cell_to_output_weight tensor
+
+          {n_cell},  // input_gate_bias tensor
+          {n_cell},  // forget_gate_bias tensor
+          {n_cell},  // cell_bias tensor
+          {n_cell},  // output_gate_bias tensor
+
+          {0, 0},  // projection_weight tensor
+          {0},     // projection_bias tensor
+
+          {n_batch, n_output},  // activation_state tensor
+          {n_batch, n_cell},    // cell_state tensor
+      });
+
+  lstm.SetInputToInputWeights(input_to_input_weights_);
+  lstm.SetInputToCellWeights(input_to_cell_weights_);
+  lstm.SetInputToForgetWeights(input_to_forget_weights_);
+  lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+  lstm.SetInputGateBias(input_gate_bias_);
+  lstm.SetCellBias(cell_gate_bias_);
+  lstm.SetForgetGateBias(forget_gate_bias_);
+  lstm.SetOutputGateBias(output_gate_bias_);
+
+  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+  // With n_batch == 1 the batch-major input and golden output have exactly
+  // the same memory layout as the time-major ones, so the existing goldens
+  // can be fed directly; no reshuffling is needed.
+  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/1e-5,
+                /*time_major=*/false);
+}
+
 TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
   const int n_batch = 1;
   const int n_input = 2;
@@ -483,7 +571,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
 
   HybridUnidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length,
-      /*use_cifg=*/false, /*use_peephole=*/false,
+      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/false,
       /*use_projection_weights=*/false,
       /*use_projection_bias=*/false,
       /*cell_clip=*/0.0, /*proj_clip=*/0.0,
       {
@@ -591,7 +679,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
 
   UnidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length,
-      /*use_cifg=*/true, /*use_peephole=*/true,
+      /*time_major=*/true, /*use_cifg=*/true, /*use_peephole=*/true,
       /*use_projection_weights=*/false,
       /*use_projection_bias=*/false,
       /*cell_clip=*/0.0, /*proj_clip=*/0.0,
@@ -652,7 +740,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
 
   HybridUnidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length,
-      /*use_cifg=*/true, /*use_peephole=*/true,
+      /*time_major=*/true, /*use_cifg=*/true, /*use_peephole=*/true,
       /*use_projection_weights=*/false,
       /*use_projection_bias=*/false,
       /*cell_clip=*/0.0, /*proj_clip=*/0.0,
@@ -1311,7 +1399,7 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
 
   UnidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length,
-      /*use_cifg=*/false, /*use_peephole=*/true,
+      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true,
      /*use_projection_weights=*/true,
       /*use_projection_bias=*/false,
       /*cell_clip=*/0.0, /*proj_clip=*/0.0,
@@ -1377,7 +1465,7 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
 
   HybridUnidirectionalLSTMOpModel lstm(
       n_batch, n_input, n_cell, n_output, sequence_length,
-      /*use_cifg=*/false, /*use_peephole=*/true,
+      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true,
       /*use_projection_weights=*/true,
       /*use_projection_bias=*/false,
       /*cell_clip=*/0.0, /*proj_clip=*/0.0,
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index fe3dc56e651..3045351f222 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -407,6 +407,9 @@ table UnidirectionalSequenceLSTMOptions {
   fused_activation_function:ActivationFunctionType;
   cell_clip: float; // Optional, 0.0 means no clipping
   proj_clip: float; // Optional, 0.0 means no clipping
+
+  // If true then first dimension is sequence, otherwise batch.
+  time_major:bool;
 }
 
 table BidirectionalSequenceLSTMOptions {
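A producer sets the new schema field through the generated builder API, mirroring what the updated test's SetBuiltinOp call does. For example (the wrapper function and its parameters are ours, purely illustrative):

```c++
#include "tensorflow/contrib/lite/schema/schema_generated.h"

// Builds the options table with the new field set; batch_major_input == true
// yields time_major == false.
flatbuffers::Offset<tflite::UnidirectionalSequenceLSTMOptions> MakeOptions(
    flatbuffers::FlatBufferBuilder* builder, bool batch_major_input) {
  return tflite::CreateUnidirectionalSequenceLSTMOptions(
      *builder, tflite::ActivationFunctionType_TANH, /*cell_clip=*/0.0f,
      /*proj_clip=*/0.0f, /*time_major=*/!batch_major_input);
}
```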
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 4426b7d407d..2bae6d72ece 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -3534,10 +3534,12 @@ struct UnidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
   ActivationFunctionType fused_activation_function;
   float cell_clip;
   float proj_clip;
+  bool time_major;
   UnidirectionalSequenceLSTMOptionsT()
       : fused_activation_function(ActivationFunctionType_NONE),
         cell_clip(0.0f),
-        proj_clip(0.0f) {
+        proj_clip(0.0f),
+        time_major(false) {
   }
 };
 
@@ -3546,7 +3548,8 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb
   enum {
     VT_FUSED_ACTIVATION_FUNCTION = 4,
     VT_CELL_CLIP = 6,
-    VT_PROJ_CLIP = 8
+    VT_PROJ_CLIP = 8,
+    VT_TIME_MAJOR = 10
   };
   ActivationFunctionType fused_activation_function() const {
     return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
@@ -3557,11 +3560,15 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb
   float proj_clip() const {
     return GetField<float>(VT_PROJ_CLIP, 0.0f);
   }
+  bool time_major() const {
+    return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
+  }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
            VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
            VerifyField<float>(verifier, VT_CELL_CLIP) &&
            VerifyField<float>(verifier, VT_PROJ_CLIP) &&
+           VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
            verifier.EndTable();
   }
   UnidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -3581,6 +3588,9 @@ struct UnidirectionalSequenceLSTMOptionsBuilder {
   void add_proj_clip(float proj_clip) {
     fbb_.AddElement<float>(UnidirectionalSequenceLSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f);
   }
+  void add_time_major(bool time_major) {
+    fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast<uint8_t>(time_major), 0);
+  }
   explicit UnidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
       : fbb_(_fbb) {
     start_ = fbb_.StartTable();
@@ -3597,10 +3607,12 @@ inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirection
     flatbuffers::FlatBufferBuilder &_fbb,
     ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
     float cell_clip = 0.0f,
-    float proj_clip = 0.0f) {
+    float proj_clip = 0.0f,
+    bool time_major = false) {
   UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
   builder_.add_proj_clip(proj_clip);
   builder_.add_cell_clip(cell_clip);
+  builder_.add_time_major(time_major);
   builder_.add_fused_activation_function(fused_activation_function);
   return builder_.Finish();
 }
@@ -8060,6 +8072,7 @@ inline void UnidirectionalSequenceLSTMOptions::UnPackTo(UnidirectionalSequenceLS
   { auto _e = fused_activation_function(); _o->fused_activation_function = _e; };
   { auto _e = cell_clip(); _o->cell_clip = _e; };
   { auto _e = proj_clip(); _o->proj_clip = _e; };
+  { auto _e = time_major(); _o->time_major = _e; };
 }
 
 inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> UnidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -8073,11 +8086,13 @@ inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirection
   auto _fused_activation_function = _o->fused_activation_function;
   auto _cell_clip = _o->cell_clip;
   auto _proj_clip = _o->proj_clip;
+  auto _time_major = _o->time_major;
   return tflite::CreateUnidirectionalSequenceLSTMOptions(
       _fbb,
       _fused_activation_function,
       _cell_clip,
-      _proj_clip);
+      _proj_clip,
+      _time_major);
 }
 
 inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
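One consequence of the generated defaults above: `time_major` unpacks to false when the field is absent, so a writer that wants the op's previous sequence-first behavior must set the field explicitly, as every updated call site in this patch does. A small sketch using the NativeTable object (illustrative; function name is ours):

```c++
#include "tensorflow/contrib/lite/schema/schema_generated.h"

flatbuffers::Offset<tflite::UnidirectionalSequenceLSTMOptions> PackTimeMajor(
    flatbuffers::FlatBufferBuilder* fbb) {
  tflite::UnidirectionalSequenceLSTMOptionsT options;
  // Constructed defaults: NONE activation, 0.0f clips, time_major == false.
  options.time_major = true;  // Opt in to the sequence-first layout.
  return tflite::UnidirectionalSequenceLSTMOptions::Pack(*fbb, &options);
}
```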