From a2f7f39d982682fa8de050e001522581570c510f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 Jan 2019 19:50:14 -0800 Subject: [PATCH] Add support non-stacking(cross-links) but connected to other bidi-lstm ops case, tensorflow equivalent: tf.nn.static_bidirectional_rnn PiperOrigin-RevId: 227791656 --- .../kernels/bidirectional_sequence_lstm.cc | 54 +++-- .../bidirectional_sequence_lstm_test.cc | 202 +++++++++++++++++- 2 files changed, 233 insertions(+), 23 deletions(-) diff --git a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc index 1ddfe7201ea..31c6e3f44c8 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc @@ -105,7 +105,10 @@ constexpr int kBwInputActivationStateTensor = 37; // Cell state tensors of size {n_batch, n_cell} constexpr int kBwInputCellStateTensor = 38; -// Auxiliary input and weights when stacking. +// Used as auxiliary input and weights when stacking for +// tf.contrib.rnn.stack_bidirectional_rnn case (with cross links); Used as input +// to the backward cell when stacking for tf.nn.static_bidirectional_rnn case +// (without cross links). constexpr int kAuxInputTensor = 39; // Optional // Forward weights. constexpr int kFwAuxInputToInputWeightsTensor = 40; // Optional @@ -459,8 +462,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* bw_aux_input_to_output_weights = GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor); - const bool aux_inputs_all_or_none = - ((aux_input != nullptr) && (fw_aux_input_to_cell_weights != nullptr) && + const bool aux_inputs_weights_all_or_none = + ((fw_aux_input_to_cell_weights != nullptr) && (fw_aux_input_to_forget_weights != nullptr) && (fw_aux_input_to_output_weights != nullptr) && (bw_aux_input_to_cell_weights != nullptr) && @@ -472,8 +475,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { (bw_aux_input_to_cell_weights == nullptr) && (bw_aux_input_to_forget_weights == nullptr) && (bw_aux_input_to_output_weights == nullptr)); - TF_LITE_ENSURE(context, aux_inputs_all_or_none); - const bool has_aux_input = (aux_input != nullptr); + TF_LITE_ENSURE(context, aux_inputs_weights_all_or_none); + + const bool has_aux_input = (fw_aux_input_to_forget_weights != nullptr); if (has_aux_input) { // Check that aux_input has the same dimensions (except last) as the input. @@ -870,6 +874,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* bw_aux_input_to_output_weights = GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor); + const bool has_previous_bw_output = (aux_input != nullptr); + const bool use_aux_input = (fw_aux_input_to_forget_weights != nullptr); + // Populate a TfLiteLSTMParams struct for the evaluation functions. TfLiteLSTMParams lstm_params = {params->activation, params->cell_clip, params->proj_clip, kTfLiteLSTMFullKernel}; @@ -879,6 +886,26 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const auto actual_bw_output = params->merge_outputs ? fw_output : bw_output; const bool time_major = params->time_major; + + // We want to cover the following cases: + // + // If not stacking (not connected after other bidi lstms): + // both fw & bw will just use `input`; aux_input will be null. + // + // If stacking with cross_links, TensorFlow equivalent + // (tf.contrib.rnn.stack_bidirectional_rnn): + // both fw & bw will use `input`, but aux_input will be none null. + // Note, this time, whether connected after other bidi lstms both works. + // + // If stacking without cross_links, but connected after other bidi lstms, + // TensorFlow equivalent (tf.nn.static_bidirectional_rnn): + // fw will use `input`, bw will use aux_input, and the `real aux_input` + // will be null. + + const bool non_stacking_mode = !use_aux_input && has_previous_bw_output; + const TfLiteTensor* bw_input = non_stacking_mode ? aux_input : input; + const TfLiteTensor* real_aux_input = non_stacking_mode ? nullptr : aux_input; + switch (fw_input_to_output_weights->type) { case kTfLiteFloat32: { TfLiteStatus fw_pass_status = lstm_eval::EvalFloat( @@ -891,7 +918,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { /*input_layer_norm_coefficients=*/nullptr, /*forget_layer_norm_coefficients=*/nullptr, /*cell_layer_norm_coefficients=*/nullptr, - /*output_layer_norm_coefficients=*/nullptr, aux_input, + /*output_layer_norm_coefficients=*/nullptr, real_aux_input, fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights, fw_input_gate_bias, fw_forget_gate_bias, fw_cell_bias, @@ -902,7 +929,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, fw_pass_status); TfLiteStatus bw_pass_status = lstm_eval::EvalFloat( - input, bw_input_to_input_weights, bw_input_to_forget_weights, + bw_input, bw_input_to_input_weights, bw_input_to_forget_weights, bw_input_to_cell_weights, bw_input_to_output_weights, bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights, bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights, @@ -911,7 +938,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { /*input_layer_norm_coefficients=*/nullptr, /*forget_layer_norm_coefficients=*/nullptr, /*cell_layer_norm_coefficients=*/nullptr, - /*output_layer_norm_coefficients=*/nullptr, aux_input, + /*output_layer_norm_coefficients=*/nullptr, real_aux_input, bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights, bw_input_gate_bias, bw_forget_gate_bias, bw_cell_bias, @@ -942,9 +969,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* recovered_cell_weights = GetTemporary(context, node, kRecoveredCellWeights); TfLiteTensor* aux_input_quantized = - (aux_input == nullptr) - ? nullptr - : GetTemporary(context, node, kAuxInputQuantized); + use_aux_input ? GetTemporary(context, node, kAuxInputQuantized) + : nullptr; TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid( input, fw_input_to_input_weights, fw_input_to_forget_weights, @@ -956,7 +982,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { /*input_layer_norm_coefficients=*/nullptr, /*forget_layer_norm_coefficients=*/nullptr, /*cell_layer_norm_coefficients=*/nullptr, - /*output_layer_norm_coefficients=*/nullptr, aux_input, + /*output_layer_norm_coefficients=*/nullptr, real_aux_input, fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights, fw_input_gate_bias, fw_forget_gate_bias, fw_cell_bias, @@ -970,7 +996,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, fw_pass_status); TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid( - input, bw_input_to_input_weights, bw_input_to_forget_weights, + bw_input, bw_input_to_input_weights, bw_input_to_forget_weights, bw_input_to_cell_weights, bw_input_to_output_weights, bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights, bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights, @@ -979,7 +1005,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { /*input_layer_norm_coefficients=*/nullptr, /*forget_layer_norm_coefficients=*/nullptr, /*cell_layer_norm_coefficients=*/nullptr, - /*output_layer_norm_coefficients=*/nullptr, aux_input, + /*output_layer_norm_coefficients=*/nullptr, real_aux_input, bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights, bw_input_gate_bias, bw_forget_gate_bias, bw_cell_bias, diff --git a/tensorflow/lite/kernels/bidirectional_sequence_lstm_test.cc b/tensorflow/lite/kernels/bidirectional_sequence_lstm_test.cc index f5df6d15af7..59ea47a2a22 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_lstm_test.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_lstm_test.cc @@ -38,7 +38,7 @@ class BidirectionalLSTMOpModel : public SingleOpModel { int sequence_length, bool use_cifg, bool use_peephole, bool use_projection_weights, bool use_projection_bias, bool merge_outputs, - float cell_clip, float proj_clip, + bool use_aux_input, float cell_clip, float proj_clip, bool quantize_weights, bool time_major, const std::vector>& input_shapes) : n_batch_(n_batch), @@ -185,7 +185,11 @@ class BidirectionalLSTMOpModel : public SingleOpModel { bw_output_ = AddOutput(TensorType_FLOAT32); } - aux_input_ = AddNullInput(); + if (use_aux_input) { + aux_input_ = AddInput(TensorType_FLOAT32); + } else { + aux_input_ = AddNullInput(); + } fw_aux_input_to_input_weights_ = AddNullInput(); fw_aux_input_to_forget_weights_ = AddNullInput(); fw_aux_input_to_cell_weights_ = AddNullInput(); @@ -302,6 +306,10 @@ class BidirectionalLSTMOpModel : public SingleOpModel { PopulateTensor(input_, offset, begin, end); } + void SetAuxInput(int offset, float* begin, float* end) { + PopulateTensor(aux_input_, offset, begin, end); + } + std::vector GetFwOutput() { return ExtractVector(fw_output_); } std::vector GetBwOutput() { return ExtractVector(bw_output_); } @@ -406,7 +414,8 @@ TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { BidirectionalLSTMOpModel lstm( n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, /*use_peephole=*/false, /*use_projection_weights=*/false, - /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0, + /*use_projection_bias=*/false, /*merge_outputs=*/false, + /*use_aux_input=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, quantize_weights, /*time_major=*/true, { {sequence_length, n_batch, n_input}, // input tensor @@ -570,7 +579,8 @@ TEST_P(LSTMOpTest, BlackBoxTestMergedOutput) { BidirectionalLSTMOpModel lstm( n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, /*use_peephole=*/false, /*use_projection_weights=*/false, - /*use_projection_bias=*/false, /*merge_outputs=*/true, /*cell_clip=*/0.0, + /*use_projection_bias=*/false, /*merge_outputs=*/true, + /*use_aux_input=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, quantize_weights, /*time_major=*/true, { {sequence_length, n_batch, n_input}, // input tensor @@ -733,7 +743,8 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse) { BidirectionalLSTMOpModel lstm( n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, /*use_peephole=*/false, /*use_projection_weights=*/false, - /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0, + /*use_projection_bias=*/false, /*merge_outputs=*/false, + /*use_aux_input=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/true, { {sequence_length, n_batch, n_input}, // input tensor @@ -895,7 +906,8 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { BidirectionalLSTMOpModel lstm( n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true, /*use_peephole=*/true, /*use_projection_weights=*/false, - /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0, + /*use_projection_bias=*/false, /*merge_outputs=*/false, + /*use_aux_input=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/true, { {sequence_length, n_batch, n_input}, // input tensor @@ -1047,7 +1059,8 @@ TEST(LSTMOpTest, BidirectionalLSTMOpModel lstm( n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true, /*use_peephole=*/true, /*use_projection_weights=*/false, - /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0, + /*use_projection_bias=*/false, /*merge_outputs=*/false, + /*use_aux_input=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/true, { {sequence_length, n_batch, n_input}, // input tensor @@ -1199,7 +1212,8 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { BidirectionalLSTMOpModel lstm( n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, /*use_peephole=*/true, /*use_projection_weights=*/true, - /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0, + /*use_projection_bias=*/false, /*merge_outputs=*/false, + /*use_aux_input=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/true, { {sequence_length, n_batch, n_input}, // input tensor @@ -1903,7 +1917,8 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClippingBatchMajor) { BidirectionalLSTMOpModel lstm( n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, /*use_peephole=*/true, /*use_projection_weights=*/true, - /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0, + /*use_projection_bias=*/false, /*merge_outputs=*/false, + /*use_aux_input=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/false, { {n_batch, sequence_length, n_input}, // input tensor @@ -2590,6 +2605,175 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClippingBatchMajor) { EXPECT_THAT(combined, ElementsAreArray(ArrayFloatNear(expected))); } +// Same as the no cifg no peephole no projection no clipping test, but have an +// aux input (without aux input weights), this is the case when stacking but no +// cross-links. +TEST_P(LSTMOpTest, BlackBoxTestWithAuxInput) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + const int sequence_length = 3; + const bool quantize_weights = GetParam(); + + BidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, + /*use_peephole=*/false, /*use_projection_weights=*/false, + /*use_projection_bias=*/false, /*merge_outputs=*/false, + /*use_aux_input=*/true, /*cell_clip=*/0.0, + /*proj_clip=*/0.0, quantize_weights, /*time_major=*/true, + { + {sequence_length, n_batch, n_input}, // input tensor + + // Forward cell + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + + // Backward cell + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + + // TODO(b/121134029): Update tests so tensor shapes after state tensor + // are used. They are currently ignored by test_util. + {sequence_length, n_batch, n_input}, // aux_input tensor + {n_cell, 0}, // aux_fw_input_to_input tensor + {n_cell, 0}, // aux_fw_input_to_forget tensor + {n_cell, 0}, // aux_fw_input_to_cell tensor + {n_cell, 0}, // aux_fw_input_to_output tensor + {n_cell, 0}, // aux_bw_input_to_input tensor + {n_cell, 0}, // aux_bw_input_to_forget tensor + {n_cell, 0}, // aux_bw_input_to_cell tensor + {n_cell, 0}, // aux_bw_input_to_output tensor + }); + + lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, + -0.34856534, 0.43890524}); + + lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, + -0.29909778}); + + lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}); + + lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, -0.1556896, + 0.19487578}); + + lstm.SetInputGateBias({0., 0., 0., 0.}); + + lstm.SetCellBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 0., 0.}); + + lstm.SetRecurrentToInputWeights( + {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, + -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, + -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); + + lstm.SetRecurrentToCellWeights( + {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, + -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}); + + lstm.SetRecurrentToForgetWeights( + {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, + -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}); + + lstm.SetRecurrentToOutputWeights( + {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, + 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}); + + // Input should have n_input * sequence_length many values. + static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; + static float lstm_fw_golden_output[] = { + -0.02973187, 0.1229473, 0.20885126, -0.15358765, + -0.03716109, 0.12507336, 0.41193449, -0.20860538, + -0.15053082, 0.09120187, 0.24278517, -0.12222792}; + static float lstm_bw_golden_output[] = { + -0.0806187, 0.139077, 0.400476, -0.197842, -0.0332076, 0.123838, + 0.309777, -0.17621, -0.0490733, 0.0739237, 0.067706, -0.0208124}; + + float* batch0_start = lstm_input; + float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); + + lstm.SetInput(0, batch0_start, batch0_end); + // Aux input and input are the same, so we should observe the same outputs + // as there's no aux input. + lstm.SetAuxInput(0, batch0_start, batch0_end); + + lstm.Invoke(); + + float* fw_golden_start = lstm_fw_golden_output; + float* fw_golden_end = + fw_golden_start + lstm.num_fw_outputs() * lstm.sequence_length(); + std::vector fw_expected; + fw_expected.insert(fw_expected.end(), fw_golden_start, fw_golden_end); + EXPECT_THAT(lstm.GetFwOutput(), + ElementsAreArray( + ArrayFloatNear(fw_expected, quantize_weights ? 1e-2 : 1e-5))); + + float* bw_golden_start = lstm_bw_golden_output; + float* bw_golden_end = + bw_golden_start + lstm.num_bw_outputs() * lstm.sequence_length(); + std::vector bw_expected; + bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end); + EXPECT_THAT(lstm.GetBwOutput(), + ElementsAreArray( + ArrayFloatNear(bw_expected, quantize_weights ? 1e-2 : 1e-5))); +} + } // namespace } // namespace tflite