Add support non-stacking(cross-links) but connected to other bidi-lstm ops case, tensorflow equivalent: tf.nn.static_bidirectional_rnn
PiperOrigin-RevId: 227791656
This commit is contained in:
parent
31c6f4b715
commit
a2f7f39d98
@ -105,7 +105,10 @@ constexpr int kBwInputActivationStateTensor = 37;
|
||||
// Cell state tensors of size {n_batch, n_cell}
|
||||
constexpr int kBwInputCellStateTensor = 38;
|
||||
|
||||
// Auxiliary input and weights when stacking.
|
||||
// Used as auxiliary input and weights when stacking for
|
||||
// tf.contrib.rnn.stack_bidirectional_rnn case (with cross links); Used as input
|
||||
// to the backward cell when stacking for tf.nn.static_bidirectional_rnn case
|
||||
// (without cross links).
|
||||
constexpr int kAuxInputTensor = 39; // Optional
|
||||
// Forward weights.
|
||||
constexpr int kFwAuxInputToInputWeightsTensor = 40; // Optional
|
||||
@ -459,8 +462,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* bw_aux_input_to_output_weights =
|
||||
GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);
|
||||
|
||||
const bool aux_inputs_all_or_none =
|
||||
((aux_input != nullptr) && (fw_aux_input_to_cell_weights != nullptr) &&
|
||||
const bool aux_inputs_weights_all_or_none =
|
||||
((fw_aux_input_to_cell_weights != nullptr) &&
|
||||
(fw_aux_input_to_forget_weights != nullptr) &&
|
||||
(fw_aux_input_to_output_weights != nullptr) &&
|
||||
(bw_aux_input_to_cell_weights != nullptr) &&
|
||||
@ -472,8 +475,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
(bw_aux_input_to_cell_weights == nullptr) &&
|
||||
(bw_aux_input_to_forget_weights == nullptr) &&
|
||||
(bw_aux_input_to_output_weights == nullptr));
|
||||
TF_LITE_ENSURE(context, aux_inputs_all_or_none);
|
||||
const bool has_aux_input = (aux_input != nullptr);
|
||||
TF_LITE_ENSURE(context, aux_inputs_weights_all_or_none);
|
||||
|
||||
const bool has_aux_input = (fw_aux_input_to_forget_weights != nullptr);
|
||||
|
||||
if (has_aux_input) {
|
||||
// Check that aux_input has the same dimensions (except last) as the input.
|
||||
@ -870,6 +874,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* bw_aux_input_to_output_weights =
|
||||
GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);
|
||||
|
||||
const bool has_previous_bw_output = (aux_input != nullptr);
|
||||
const bool use_aux_input = (fw_aux_input_to_forget_weights != nullptr);
|
||||
|
||||
// Populate a TfLiteLSTMParams struct for the evaluation functions.
|
||||
TfLiteLSTMParams lstm_params = {params->activation, params->cell_clip,
|
||||
params->proj_clip, kTfLiteLSTMFullKernel};
|
||||
@ -879,6 +886,26 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const auto actual_bw_output = params->merge_outputs ? fw_output : bw_output;
|
||||
|
||||
const bool time_major = params->time_major;
|
||||
|
||||
// We want to cover the following cases:
|
||||
//
|
||||
// If not stacking (not connected after other bidi lstms):
|
||||
// both fw & bw will just use `input`; aux_input will be null.
|
||||
//
|
||||
// If stacking with cross_links, TensorFlow equivalent
|
||||
// (tf.contrib.rnn.stack_bidirectional_rnn):
|
||||
// both fw & bw will use `input`, but aux_input will be none null.
|
||||
// Note, this time, whether connected after other bidi lstms both works.
|
||||
//
|
||||
// If stacking without cross_links, but connected after other bidi lstms,
|
||||
// TensorFlow equivalent (tf.nn.static_bidirectional_rnn):
|
||||
// fw will use `input`, bw will use aux_input, and the `real aux_input`
|
||||
// will be null.
|
||||
|
||||
const bool non_stacking_mode = !use_aux_input && has_previous_bw_output;
|
||||
const TfLiteTensor* bw_input = non_stacking_mode ? aux_input : input;
|
||||
const TfLiteTensor* real_aux_input = non_stacking_mode ? nullptr : aux_input;
|
||||
|
||||
switch (fw_input_to_output_weights->type) {
|
||||
case kTfLiteFloat32: {
|
||||
TfLiteStatus fw_pass_status = lstm_eval::EvalFloat(
|
||||
@ -891,7 +918,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
/*input_layer_norm_coefficients=*/nullptr,
|
||||
/*forget_layer_norm_coefficients=*/nullptr,
|
||||
/*cell_layer_norm_coefficients=*/nullptr,
|
||||
/*output_layer_norm_coefficients=*/nullptr, aux_input,
|
||||
/*output_layer_norm_coefficients=*/nullptr, real_aux_input,
|
||||
fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
|
||||
fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
|
||||
fw_input_gate_bias, fw_forget_gate_bias, fw_cell_bias,
|
||||
@ -902,7 +929,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_OK(context, fw_pass_status);
|
||||
|
||||
TfLiteStatus bw_pass_status = lstm_eval::EvalFloat(
|
||||
input, bw_input_to_input_weights, bw_input_to_forget_weights,
|
||||
bw_input, bw_input_to_input_weights, bw_input_to_forget_weights,
|
||||
bw_input_to_cell_weights, bw_input_to_output_weights,
|
||||
bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
|
||||
bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
|
||||
@ -911,7 +938,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
/*input_layer_norm_coefficients=*/nullptr,
|
||||
/*forget_layer_norm_coefficients=*/nullptr,
|
||||
/*cell_layer_norm_coefficients=*/nullptr,
|
||||
/*output_layer_norm_coefficients=*/nullptr, aux_input,
|
||||
/*output_layer_norm_coefficients=*/nullptr, real_aux_input,
|
||||
bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
|
||||
bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
|
||||
bw_input_gate_bias, bw_forget_gate_bias, bw_cell_bias,
|
||||
@ -942,9 +969,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteTensor* recovered_cell_weights =
|
||||
GetTemporary(context, node, kRecoveredCellWeights);
|
||||
TfLiteTensor* aux_input_quantized =
|
||||
(aux_input == nullptr)
|
||||
? nullptr
|
||||
: GetTemporary(context, node, kAuxInputQuantized);
|
||||
use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
|
||||
: nullptr;
|
||||
|
||||
TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
|
||||
input, fw_input_to_input_weights, fw_input_to_forget_weights,
|
||||
@ -956,7 +982,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
/*input_layer_norm_coefficients=*/nullptr,
|
||||
/*forget_layer_norm_coefficients=*/nullptr,
|
||||
/*cell_layer_norm_coefficients=*/nullptr,
|
||||
/*output_layer_norm_coefficients=*/nullptr, aux_input,
|
||||
/*output_layer_norm_coefficients=*/nullptr, real_aux_input,
|
||||
fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
|
||||
fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
|
||||
fw_input_gate_bias, fw_forget_gate_bias, fw_cell_bias,
|
||||
@ -970,7 +996,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_OK(context, fw_pass_status);
|
||||
|
||||
TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
|
||||
input, bw_input_to_input_weights, bw_input_to_forget_weights,
|
||||
bw_input, bw_input_to_input_weights, bw_input_to_forget_weights,
|
||||
bw_input_to_cell_weights, bw_input_to_output_weights,
|
||||
bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
|
||||
bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
|
||||
@ -979,7 +1005,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
/*input_layer_norm_coefficients=*/nullptr,
|
||||
/*forget_layer_norm_coefficients=*/nullptr,
|
||||
/*cell_layer_norm_coefficients=*/nullptr,
|
||||
/*output_layer_norm_coefficients=*/nullptr, aux_input,
|
||||
/*output_layer_norm_coefficients=*/nullptr, real_aux_input,
|
||||
bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
|
||||
bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
|
||||
bw_input_gate_bias, bw_forget_gate_bias, bw_cell_bias,
|
||||
|
@ -38,7 +38,7 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
|
||||
int sequence_length, bool use_cifg,
|
||||
bool use_peephole, bool use_projection_weights,
|
||||
bool use_projection_bias, bool merge_outputs,
|
||||
float cell_clip, float proj_clip,
|
||||
bool use_aux_input, float cell_clip, float proj_clip,
|
||||
bool quantize_weights, bool time_major,
|
||||
const std::vector<std::vector<int>>& input_shapes)
|
||||
: n_batch_(n_batch),
|
||||
@ -185,7 +185,11 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
|
||||
bw_output_ = AddOutput(TensorType_FLOAT32);
|
||||
}
|
||||
|
||||
aux_input_ = AddNullInput();
|
||||
if (use_aux_input) {
|
||||
aux_input_ = AddInput(TensorType_FLOAT32);
|
||||
} else {
|
||||
aux_input_ = AddNullInput();
|
||||
}
|
||||
fw_aux_input_to_input_weights_ = AddNullInput();
|
||||
fw_aux_input_to_forget_weights_ = AddNullInput();
|
||||
fw_aux_input_to_cell_weights_ = AddNullInput();
|
||||
@ -302,6 +306,10 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
|
||||
PopulateTensor(input_, offset, begin, end);
|
||||
}
|
||||
|
||||
void SetAuxInput(int offset, float* begin, float* end) {
|
||||
PopulateTensor(aux_input_, offset, begin, end);
|
||||
}
|
||||
|
||||
std::vector<float> GetFwOutput() { return ExtractVector<float>(fw_output_); }
|
||||
std::vector<float> GetBwOutput() { return ExtractVector<float>(bw_output_); }
|
||||
|
||||
@ -406,7 +414,8 @@ TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
|
||||
BidirectionalLSTMOpModel lstm(
|
||||
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
|
||||
/*use_peephole=*/false, /*use_projection_weights=*/false,
|
||||
/*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
|
||||
/*use_projection_bias=*/false, /*merge_outputs=*/false,
|
||||
/*use_aux_input=*/false, /*cell_clip=*/0.0,
|
||||
/*proj_clip=*/0.0, quantize_weights, /*time_major=*/true,
|
||||
{
|
||||
{sequence_length, n_batch, n_input}, // input tensor
|
||||
@ -570,7 +579,8 @@ TEST_P(LSTMOpTest, BlackBoxTestMergedOutput) {
|
||||
BidirectionalLSTMOpModel lstm(
|
||||
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
|
||||
/*use_peephole=*/false, /*use_projection_weights=*/false,
|
||||
/*use_projection_bias=*/false, /*merge_outputs=*/true, /*cell_clip=*/0.0,
|
||||
/*use_projection_bias=*/false, /*merge_outputs=*/true,
|
||||
/*use_aux_input=*/false, /*cell_clip=*/0.0,
|
||||
/*proj_clip=*/0.0, quantize_weights, /*time_major=*/true,
|
||||
{
|
||||
{sequence_length, n_batch, n_input}, // input tensor
|
||||
@ -733,7 +743,8 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse) {
|
||||
BidirectionalLSTMOpModel lstm(
|
||||
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
|
||||
/*use_peephole=*/false, /*use_projection_weights=*/false,
|
||||
/*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
|
||||
/*use_projection_bias=*/false, /*merge_outputs=*/false,
|
||||
/*use_aux_input=*/false, /*cell_clip=*/0.0,
|
||||
/*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/true,
|
||||
{
|
||||
{sequence_length, n_batch, n_input}, // input tensor
|
||||
@ -895,7 +906,8 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
|
||||
BidirectionalLSTMOpModel lstm(
|
||||
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true,
|
||||
/*use_peephole=*/true, /*use_projection_weights=*/false,
|
||||
/*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
|
||||
/*use_projection_bias=*/false, /*merge_outputs=*/false,
|
||||
/*use_aux_input=*/false, /*cell_clip=*/0.0,
|
||||
/*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/true,
|
||||
{
|
||||
{sequence_length, n_batch, n_input}, // input tensor
|
||||
@ -1047,7 +1059,8 @@ TEST(LSTMOpTest,
|
||||
BidirectionalLSTMOpModel lstm(
|
||||
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true,
|
||||
/*use_peephole=*/true, /*use_projection_weights=*/false,
|
||||
/*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
|
||||
/*use_projection_bias=*/false, /*merge_outputs=*/false,
|
||||
/*use_aux_input=*/false, /*cell_clip=*/0.0,
|
||||
/*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/true,
|
||||
{
|
||||
{sequence_length, n_batch, n_input}, // input tensor
|
||||
@ -1199,7 +1212,8 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
|
||||
BidirectionalLSTMOpModel lstm(
|
||||
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
|
||||
/*use_peephole=*/true, /*use_projection_weights=*/true,
|
||||
/*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
|
||||
/*use_projection_bias=*/false, /*merge_outputs=*/false,
|
||||
/*use_aux_input=*/false, /*cell_clip=*/0.0,
|
||||
/*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/true,
|
||||
{
|
||||
{sequence_length, n_batch, n_input}, // input tensor
|
||||
@ -1903,7 +1917,8 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClippingBatchMajor) {
|
||||
BidirectionalLSTMOpModel lstm(
|
||||
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
|
||||
/*use_peephole=*/true, /*use_projection_weights=*/true,
|
||||
/*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
|
||||
/*use_projection_bias=*/false, /*merge_outputs=*/false,
|
||||
/*use_aux_input=*/false, /*cell_clip=*/0.0,
|
||||
/*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/false,
|
||||
{
|
||||
{n_batch, sequence_length, n_input}, // input tensor
|
||||
@ -2590,6 +2605,175 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClippingBatchMajor) {
|
||||
EXPECT_THAT(combined, ElementsAreArray(ArrayFloatNear(expected)));
|
||||
}
|
||||
|
||||
// Same as the no cifg no peephole no projection no clipping test, but have an
|
||||
// aux input (without aux input weights), this is the case when stacking but no
|
||||
// cross-links.
|
||||
TEST_P(LSTMOpTest, BlackBoxTestWithAuxInput) {
|
||||
const int n_batch = 1;
|
||||
const int n_input = 2;
|
||||
// n_cell and n_output have the same size when there is no projection.
|
||||
const int n_cell = 4;
|
||||
const int n_output = 4;
|
||||
const int sequence_length = 3;
|
||||
const bool quantize_weights = GetParam();
|
||||
|
||||
BidirectionalLSTMOpModel lstm(
|
||||
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
|
||||
/*use_peephole=*/false, /*use_projection_weights=*/false,
|
||||
/*use_projection_bias=*/false, /*merge_outputs=*/false,
|
||||
/*use_aux_input=*/true, /*cell_clip=*/0.0,
|
||||
/*proj_clip=*/0.0, quantize_weights, /*time_major=*/true,
|
||||
{
|
||||
{sequence_length, n_batch, n_input}, // input tensor
|
||||
|
||||
// Forward cell
|
||||
{n_cell, n_input}, // input_to_input_weight tensor
|
||||
{n_cell, n_input}, // input_to_forget_weight tensor
|
||||
{n_cell, n_input}, // input_to_cell_weight tensor
|
||||
{n_cell, n_input}, // input_to_output_weight tensor
|
||||
|
||||
{n_cell, n_output}, // recurrent_to_input_weight tensor
|
||||
{n_cell, n_output}, // recurrent_to_forget_weight tensor
|
||||
{n_cell, n_output}, // recurrent_to_cell_weight tensor
|
||||
{n_cell, n_output}, // recurrent_to_output_weight tensor
|
||||
|
||||
{0}, // cell_to_input_weight tensor
|
||||
{0}, // cell_to_forget_weight tensor
|
||||
{0}, // cell_to_output_weight tensor
|
||||
|
||||
{n_cell}, // input_gate_bias tensor
|
||||
{n_cell}, // forget_gate_bias tensor
|
||||
{n_cell}, // cell_bias tensor
|
||||
{n_cell}, // output_gate_bias tensor
|
||||
|
||||
{0, 0}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
// Backward cell
|
||||
{n_cell, n_input}, // input_to_input_weight tensor
|
||||
{n_cell, n_input}, // input_to_forget_weight tensor
|
||||
{n_cell, n_input}, // input_to_cell_weight tensor
|
||||
{n_cell, n_input}, // input_to_output_weight tensor
|
||||
|
||||
{n_cell, n_output}, // recurrent_to_input_weight tensor
|
||||
{n_cell, n_output}, // recurrent_to_forget_weight tensor
|
||||
{n_cell, n_output}, // recurrent_to_cell_weight tensor
|
||||
{n_cell, n_output}, // recurrent_to_output_weight tensor
|
||||
|
||||
{0}, // cell_to_input_weight tensor
|
||||
{0}, // cell_to_forget_weight tensor
|
||||
{0}, // cell_to_output_weight tensor
|
||||
|
||||
{n_cell}, // input_gate_bias tensor
|
||||
{n_cell}, // forget_gate_bias tensor
|
||||
{n_cell}, // cell_bias tensor
|
||||
{n_cell}, // output_gate_bias tensor
|
||||
|
||||
{0, 0}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
|
||||
// TODO(b/121134029): Update tests so tensor shapes after state tensor
|
||||
// are used. They are currently ignored by test_util.
|
||||
{sequence_length, n_batch, n_input}, // aux_input tensor
|
||||
{n_cell, 0}, // aux_fw_input_to_input tensor
|
||||
{n_cell, 0}, // aux_fw_input_to_forget tensor
|
||||
{n_cell, 0}, // aux_fw_input_to_cell tensor
|
||||
{n_cell, 0}, // aux_fw_input_to_output tensor
|
||||
{n_cell, 0}, // aux_bw_input_to_input tensor
|
||||
{n_cell, 0}, // aux_bw_input_to_forget tensor
|
||||
{n_cell, 0}, // aux_bw_input_to_cell tensor
|
||||
{n_cell, 0}, // aux_bw_input_to_output tensor
|
||||
});
|
||||
|
||||
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
|
||||
-0.34550029, 0.04266912, -0.15680569,
|
||||
-0.34856534, 0.43890524});
|
||||
|
||||
lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
|
||||
-0.20583314, 0.44344562, 0.22077113,
|
||||
-0.29909778});
|
||||
|
||||
lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
|
||||
-0.31343272, -0.40032279, 0.44781327,
|
||||
0.01387155, -0.35593212});
|
||||
|
||||
lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
|
||||
0.40525138, 0.44272184, 0.03897077, -0.1556896,
|
||||
0.19487578});
|
||||
|
||||
lstm.SetInputGateBias({0., 0., 0., 0.});
|
||||
|
||||
lstm.SetCellBias({0., 0., 0., 0.});
|
||||
|
||||
lstm.SetForgetGateBias({1., 1., 1., 1.});
|
||||
|
||||
lstm.SetOutputGateBias({0., 0., 0., 0.});
|
||||
|
||||
lstm.SetRecurrentToInputWeights(
|
||||
{-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
|
||||
-0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
|
||||
-0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
|
||||
|
||||
lstm.SetRecurrentToCellWeights(
|
||||
{-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
|
||||
-0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
|
||||
-0.46367589, 0.26016325, -0.03894562, -0.16368064});
|
||||
|
||||
lstm.SetRecurrentToForgetWeights(
|
||||
{-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
|
||||
-0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
|
||||
0.28053468, 0.01560611, -0.20127171, -0.01140004});
|
||||
|
||||
lstm.SetRecurrentToOutputWeights(
|
||||
{0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
|
||||
0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
|
||||
-0.51818722, -0.15390486, 0.0468148, 0.39922136});
|
||||
|
||||
// Input should have n_input * sequence_length many values.
|
||||
static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
|
||||
static float lstm_fw_golden_output[] = {
|
||||
-0.02973187, 0.1229473, 0.20885126, -0.15358765,
|
||||
-0.03716109, 0.12507336, 0.41193449, -0.20860538,
|
||||
-0.15053082, 0.09120187, 0.24278517, -0.12222792};
|
||||
static float lstm_bw_golden_output[] = {
|
||||
-0.0806187, 0.139077, 0.400476, -0.197842, -0.0332076, 0.123838,
|
||||
0.309777, -0.17621, -0.0490733, 0.0739237, 0.067706, -0.0208124};
|
||||
|
||||
float* batch0_start = lstm_input;
|
||||
float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
|
||||
|
||||
lstm.SetInput(0, batch0_start, batch0_end);
|
||||
// Aux input and input are the same, so we should observe the same outputs
|
||||
// as there's no aux input.
|
||||
lstm.SetAuxInput(0, batch0_start, batch0_end);
|
||||
|
||||
lstm.Invoke();
|
||||
|
||||
float* fw_golden_start = lstm_fw_golden_output;
|
||||
float* fw_golden_end =
|
||||
fw_golden_start + lstm.num_fw_outputs() * lstm.sequence_length();
|
||||
std::vector<float> fw_expected;
|
||||
fw_expected.insert(fw_expected.end(), fw_golden_start, fw_golden_end);
|
||||
EXPECT_THAT(lstm.GetFwOutput(),
|
||||
ElementsAreArray(
|
||||
ArrayFloatNear(fw_expected, quantize_weights ? 1e-2 : 1e-5)));
|
||||
|
||||
float* bw_golden_start = lstm_bw_golden_output;
|
||||
float* bw_golden_end =
|
||||
bw_golden_start + lstm.num_bw_outputs() * lstm.sequence_length();
|
||||
std::vector<float> bw_expected;
|
||||
bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end);
|
||||
EXPECT_THAT(lstm.GetBwOutput(),
|
||||
ElementsAreArray(
|
||||
ArrayFloatNear(bw_expected, quantize_weights ? 1e-2 : 1e-5)));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user