3358 lines
157 KiB
C++
3358 lines
157 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
// Unit tests for the TFLite UNIDIRECTIONAL_SEQUENCE_LSTM op.
|
|
|
|
#include <vector>
|
|
|
|
#include <gmock/gmock.h>
|
|
#include <gtest/gtest.h>
|
|
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
|
#include "tensorflow/lite/kernels/test_util.h"
|
|
#include "tensorflow/lite/schema/schema_generated.h"
|
|
|
|
namespace tflite {
|
|
namespace {
|
|
|
|
using ::testing::ElementsAreArray;
|
|
|
|
class UnidirectionalLSTMOpModel : public SingleOpModel {
|
|
public:
|
|
UnidirectionalLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
|
|
int sequence_length, bool time_major, bool use_cifg,
|
|
bool use_peephole, bool use_projection_weights,
|
|
bool use_projection_bias, float cell_clip,
|
|
float proj_clip,
|
|
const std::vector<std::vector<int>>& input_shapes,
|
|
const TensorType& weights_type = TensorType_FLOAT32,
|
|
bool is_layer_norm = false,
|
|
bool asymmetric_quantize_inputs = false)
|
|
: n_batch_(n_batch),
|
|
n_input_(n_input),
|
|
n_cell_(n_cell),
|
|
n_output_(n_output),
|
|
sequence_length_(sequence_length) {
|
|
input_ = AddInput(TensorType_FLOAT32);
|
|
|
|
if (use_cifg) {
|
|
input_to_input_weights_ = AddNullInput();
|
|
} else {
|
|
input_to_input_weights_ = AddInput(weights_type);
|
|
}
|
|
|
|
input_to_forget_weights_ = AddInput(weights_type);
|
|
input_to_cell_weights_ = AddInput(weights_type);
|
|
input_to_output_weights_ = AddInput(weights_type);
|
|
|
|
if (use_cifg) {
|
|
recurrent_to_input_weights_ = AddNullInput();
|
|
} else {
|
|
recurrent_to_input_weights_ = AddInput(weights_type);
|
|
}
|
|
|
|
recurrent_to_forget_weights_ = AddInput(weights_type);
|
|
recurrent_to_cell_weights_ = AddInput(weights_type);
|
|
recurrent_to_output_weights_ = AddInput(weights_type);
|
|
|
|
if (use_peephole) {
|
|
if (use_cifg) {
|
|
cell_to_input_weights_ = AddNullInput();
|
|
} else {
|
|
cell_to_input_weights_ = AddInput(weights_type);
|
|
}
|
|
cell_to_forget_weights_ = AddInput(weights_type);
|
|
cell_to_output_weights_ = AddInput(weights_type);
|
|
} else {
|
|
cell_to_input_weights_ = AddNullInput();
|
|
cell_to_forget_weights_ = AddNullInput();
|
|
cell_to_output_weights_ = AddNullInput();
|
|
}
|
|
|
|
if (use_cifg) {
|
|
input_gate_bias_ = AddNullInput();
|
|
} else {
|
|
input_gate_bias_ = AddInput(TensorType_FLOAT32);
|
|
}
|
|
forget_gate_bias_ = AddInput(TensorType_FLOAT32);
|
|
cell_gate_bias_ = AddInput(TensorType_FLOAT32);
|
|
output_gate_bias_ = AddInput(TensorType_FLOAT32);
|
|
|
|
if (use_projection_weights) {
|
|
projection_weights_ = AddInput(weights_type);
|
|
if (use_projection_bias) {
|
|
projection_bias_ = AddInput(TensorType_FLOAT32);
|
|
} else {
|
|
projection_bias_ = AddNullInput();
|
|
}
|
|
} else {
|
|
projection_weights_ = AddNullInput();
|
|
projection_bias_ = AddNullInput();
|
|
}
|
|
|
|
// Adding the 2 state tensors.
|
|
output_state_ = AddVariableInput(
|
|
TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}});
|
|
cell_state_ =
|
|
AddVariableInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}});
|
|
|
|
// Layer norm weights.
|
|
if (is_layer_norm) {
|
|
if (use_cifg) {
|
|
input_layer_norm_coefficients_ = AddNullInput();
|
|
} else {
|
|
input_layer_norm_coefficients_ =
|
|
AddLayerNormCoeffsTensor(20, input_shapes);
|
|
}
|
|
forget_layer_norm_coefficients_ =
|
|
AddLayerNormCoeffsTensor(21, input_shapes);
|
|
cell_layer_norm_coefficients_ =
|
|
AddLayerNormCoeffsTensor(22, input_shapes);
|
|
output_layer_norm_coefficients_ =
|
|
AddLayerNormCoeffsTensor(23, input_shapes);
|
|
}
|
|
|
|
output_ = AddOutput(TensorType_FLOAT32);
|
|
|
|
SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
|
|
BuiltinOptions_UnidirectionalSequenceLSTMOptions,
|
|
CreateUnidirectionalSequenceLSTMOptions(
|
|
builder_, ActivationFunctionType_TANH, cell_clip,
|
|
proj_clip, time_major, asymmetric_quantize_inputs)
|
|
.Union());
|
|
BuildInterpreter(input_shapes);
|
|
}
|
|
|
|
void SetInputToInputWeights(const std::vector<float>& f) {
|
|
PopulateTensor(input_to_input_weights_, f);
|
|
}
|
|
|
|
void SetInputToForgetWeights(const std::vector<float>& f) {
|
|
PopulateTensor(input_to_forget_weights_, f);
|
|
}
|
|
|
|
void SetInputToCellWeights(const std::vector<float>& f) {
|
|
PopulateTensor(input_to_cell_weights_, f);
|
|
}
|
|
|
|
void SetInputToOutputWeights(const std::vector<float>& f) {
|
|
PopulateTensor(input_to_output_weights_, f);
|
|
}
|
|
|
|
void SetRecurrentToInputWeights(const std::vector<float>& f) {
|
|
PopulateTensor(recurrent_to_input_weights_, f);
|
|
}
|
|
|
|
void SetRecurrentToForgetWeights(const std::vector<float>& f) {
|
|
PopulateTensor(recurrent_to_forget_weights_, f);
|
|
}
|
|
|
|
void SetRecurrentToCellWeights(const std::vector<float>& f) {
|
|
PopulateTensor(recurrent_to_cell_weights_, f);
|
|
}
|
|
|
|
void SetRecurrentToOutputWeights(const std::vector<float>& f) {
|
|
PopulateTensor(recurrent_to_output_weights_, f);
|
|
}
|
|
|
|
void SetCellToInputWeights(const std::vector<float>& f) {
|
|
PopulateTensor(cell_to_input_weights_, f);
|
|
}
|
|
|
|
void SetCellToForgetWeights(const std::vector<float>& f) {
|
|
PopulateTensor(cell_to_forget_weights_, f);
|
|
}
|
|
|
|
void SetCellToOutputWeights(const std::vector<float>& f) {
|
|
PopulateTensor(cell_to_output_weights_, f);
|
|
}
|
|
|
|
void SetInputGateBias(const std::vector<float>& f) {
|
|
PopulateTensor(input_gate_bias_, f);
|
|
}
|
|
|
|
void SetForgetGateBias(const std::vector<float>& f) {
|
|
PopulateTensor(forget_gate_bias_, f);
|
|
}
|
|
|
|
void SetCellBias(const std::vector<float>& f) {
|
|
PopulateTensor(cell_gate_bias_, f);
|
|
}
|
|
|
|
void SetOutputGateBias(const std::vector<float>& f) {
|
|
PopulateTensor(output_gate_bias_, f);
|
|
}
|
|
|
|
void SetProjectionWeights(const std::vector<float>& f) {
|
|
PopulateTensor(projection_weights_, f);
|
|
}
|
|
|
|
void SetProjectionBias(const std::vector<float>& f) {
|
|
PopulateTensor(projection_bias_, f);
|
|
}
|
|
|
|
void SetInputLayerNormCoefficients(std::vector<float> f) {
|
|
PopulateTensor(input_layer_norm_coefficients_, f);
|
|
}
|
|
|
|
void SetForgetLayerNormCoefficients(std::vector<float> f) {
|
|
PopulateTensor(forget_layer_norm_coefficients_, f);
|
|
}
|
|
|
|
void SetCellLayerNormCoefficients(std::vector<float> f) {
|
|
PopulateTensor(cell_layer_norm_coefficients_, f);
|
|
}
|
|
|
|
void SetOutputLayerNormCoefficients(std::vector<float> f) {
|
|
PopulateTensor(output_layer_norm_coefficients_, f);
|
|
}
|
|
|
|
void SetInput(int offset, const float* begin, const float* end) {
|
|
PopulateTensor(input_, offset, const_cast<float*>(begin),
|
|
const_cast<float*>(end));
|
|
}
|
|
|
|
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
|
|
|
|
int num_inputs() { return n_input_; }
|
|
int num_outputs() { return n_output_; }
|
|
int num_cells() { return n_cell_; }
|
|
int num_batches() { return n_batch_; }
|
|
int sequence_length() { return sequence_length_; }
|
|
|
|
protected:
|
|
int input_;
|
|
int input_to_input_weights_;
|
|
int input_to_forget_weights_;
|
|
int input_to_cell_weights_;
|
|
int input_to_output_weights_;
|
|
|
|
int recurrent_to_input_weights_;
|
|
int recurrent_to_forget_weights_;
|
|
int recurrent_to_cell_weights_;
|
|
int recurrent_to_output_weights_;
|
|
|
|
int cell_to_input_weights_;
|
|
int cell_to_forget_weights_;
|
|
int cell_to_output_weights_;
|
|
|
|
int input_gate_bias_;
|
|
int forget_gate_bias_;
|
|
int cell_gate_bias_;
|
|
int output_gate_bias_;
|
|
|
|
int projection_weights_;
|
|
int projection_bias_;
|
|
|
|
int output_state_;
|
|
int cell_state_;
|
|
|
|
int input_layer_norm_coefficients_;
|
|
int forget_layer_norm_coefficients_;
|
|
int cell_layer_norm_coefficients_;
|
|
int output_layer_norm_coefficients_;
|
|
|
|
int output_;
|
|
|
|
int n_batch_;
|
|
int n_input_;
|
|
int n_cell_;
|
|
int n_output_;
|
|
int sequence_length_;
|
|
|
|
private:
|
|
int AddLayerNormCoeffsTensor(
|
|
int tensor_index, const std::vector<std::vector<int>>& input_shapes) {
|
|
if (input_shapes[tensor_index][0] != 0) {
|
|
return AddInput(TensorType_FLOAT32);
|
|
} else {
|
|
return AddNullInput();
|
|
}
|
|
}
|
|
};
|
|
|
|
// The hybrid model has quantized weights.
|
|
class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
|
|
public:
|
|
HybridUnidirectionalLSTMOpModel(
|
|
int n_batch, int n_input, int n_cell, int n_output, int sequence_length,
|
|
bool time_major, bool use_cifg, bool use_peephole,
|
|
bool use_projection_weights, bool use_projection_bias, float cell_clip,
|
|
float proj_clip, const std::vector<std::vector<int>>& input_shapes,
|
|
TensorType tensor_type, bool asymmetric_quantize_inputs)
|
|
: UnidirectionalLSTMOpModel(
|
|
n_batch, n_input, n_cell, n_output, sequence_length, time_major,
|
|
use_cifg, use_peephole, use_projection_weights, use_projection_bias,
|
|
cell_clip, proj_clip, input_shapes, tensor_type, false,
|
|
asymmetric_quantize_inputs) {
|
|
tensor_type_ = tensor_type;
|
|
}
|
|
|
|
void SetWeights(int weights_idx, const std::vector<float>& f) {
|
|
if (tensor_type_ == TensorType_UINT8) {
|
|
SymmetricQuantizeAndPopulate(weights_idx, f);
|
|
} else {
|
|
SignedSymmetricQuantizeAndPopulate(weights_idx, f);
|
|
}
|
|
}
|
|
|
|
void SetInputToInputWeights(const std::vector<float>& f) {
|
|
SetWeights(input_to_input_weights_, f);
|
|
}
|
|
|
|
void SetInputToForgetWeights(const std::vector<float>& f) {
|
|
SetWeights(input_to_forget_weights_, f);
|
|
}
|
|
|
|
void SetInputToCellWeights(const std::vector<float>& f) {
|
|
SetWeights(input_to_cell_weights_, f);
|
|
}
|
|
|
|
void SetInputToOutputWeights(const std::vector<float>& f) {
|
|
SetWeights(input_to_output_weights_, f);
|
|
}
|
|
|
|
void SetRecurrentToInputWeights(const std::vector<float>& f) {
|
|
SetWeights(recurrent_to_input_weights_, f);
|
|
}
|
|
|
|
void SetRecurrentToForgetWeights(const std::vector<float>& f) {
|
|
SetWeights(recurrent_to_forget_weights_, f);
|
|
}
|
|
|
|
void SetRecurrentToCellWeights(const std::vector<float>& f) {
|
|
SetWeights(recurrent_to_cell_weights_, f);
|
|
}
|
|
|
|
void SetRecurrentToOutputWeights(const std::vector<float>& f) {
|
|
SetWeights(recurrent_to_output_weights_, f);
|
|
}
|
|
|
|
void SetCellToInputWeights(const std::vector<float>& f) {
|
|
SetWeights(cell_to_input_weights_, f);
|
|
}
|
|
|
|
void SetCellToForgetWeights(const std::vector<float>& f) {
|
|
SetWeights(cell_to_forget_weights_, f);
|
|
}
|
|
|
|
void SetCellToOutputWeights(const std::vector<float>& f) {
|
|
SetWeights(cell_to_output_weights_, f);
|
|
}
|
|
|
|
void SetProjectionWeights(const std::vector<float>& f) {
|
|
SetWeights(projection_weights_, f);
|
|
}
|
|
|
|
protected:
|
|
TensorType tensor_type_;
|
|
};
|
|
|
|
// Shared fixture for the unidirectional sequence LSTM tests. Subclasses fill
// in the weight vectors, the inputs and the golden outputs, and tests call
// VerifyGoldens() once a model has been configured. TEST_P tests pass the bool
// parameter (GetParam()) to the hybrid model as `asymmetric_quantize_inputs`.
class BaseUnidirectionalLstmTest : public ::testing::TestWithParam<bool> {
 protected:
  // Weights of the LSTM model. Some are optional.
  std::vector<float> input_to_input_weights_;
  std::vector<float> input_to_cell_weights_;
  std::vector<float> input_to_forget_weights_;
  std::vector<float> input_to_output_weights_;
  std::vector<float> input_gate_bias_;
  std::vector<float> cell_gate_bias_;
  std::vector<float> forget_gate_bias_;
  std::vector<float> output_gate_bias_;
  std::vector<float> recurrent_to_input_weights_;
  std::vector<float> recurrent_to_cell_weights_;
  std::vector<float> recurrent_to_forget_weights_;
  std::vector<float> recurrent_to_output_weights_;
  std::vector<float> cell_to_input_weights_;
  std::vector<float> cell_to_forget_weights_;
  std::vector<float> cell_to_output_weights_;
  std::vector<float> projection_weights_;
  std::vector<float> projection_bias_;

  // LSTM input: one vector per batch, each holding the whole sequence
  // (sequence_length x num_inputs values, one time step after another).
  std::vector<std::vector<float>> lstm_input_;
  // Golden LSTM output: one vector per batch, each holding the whole sequence
  // (sequence_length x num_outputs values, one time step after another).
  std::vector<std::vector<float>> lstm_golden_output_;

  // Feeds `input` to `lstm`, invokes it, and checks its output element-wise
  // against `output` within `tolerance`.
  //
  // The per-batch vectors are interleaved to match the model layout: when
  // `time_major` is true the tensors are filled [time][batch][...], otherwise
  // [batch][time][...].
  void VerifyGoldens(const std::vector<std::vector<float>>& input,
                     const std::vector<std::vector<float>>& output,
                     UnidirectionalLSTMOpModel* lstm, float tolerance = 1e-5,
                     bool time_major = true) {
    // Fatal ASSERTs: the code below indexes input[0], divides by num_inputs
    // and reads a full sequence from every input/output vector, so continuing
    // after a failed size check would read out of bounds.
    const int num_batches = static_cast<int>(input.size());
    ASSERT_GT(num_batches, 0);
    const int num_inputs = lstm->num_inputs();
    ASSERT_GT(num_inputs, 0);
    const int input_sequence_size =
        static_cast<int>(input[0].size()) / num_inputs;
    ASSERT_GT(input_sequence_size, 0);
    const int num_outputs = lstm->num_outputs();
    ASSERT_GT(num_outputs, 0);
    ASSERT_GE(static_cast<int>(output.size()), num_batches);
    for (int b = 0; b < num_batches; ++b) {
      ASSERT_GE(static_cast<int>(input[b].size()),
                input_sequence_size * num_inputs);
      ASSERT_GE(static_cast<int>(output[b].size()),
                input_sequence_size * num_outputs);
    }

    if (time_major) {
      // Feed the whole sequence as input, one time step of every batch at a
      // time.
      for (int i = 0; i < input_sequence_size; ++i) {
        for (int b = 0; b < num_batches; ++b) {
          const float* batch_start = input[b].data() + i * num_inputs;
          const float* batch_end = batch_start + num_inputs;
          lstm->SetInput(((i * num_batches) + b) * num_inputs, batch_start,
                         batch_end);
        }
      }
    } else {
      // Batch major: each batch's sequence is contiguous.
      for (int b = 0; b < num_batches; ++b) {
        const float* batch_start = input[b].data();
        const float* batch_end = batch_start + input_sequence_size * num_inputs;
        lstm->SetInput(b * input_sequence_size * num_inputs, batch_start,
                       batch_end);
      }
    }

    lstm->Invoke();

    std::vector<float> expected;
    expected.reserve(num_batches * input_sequence_size * num_outputs);
    if (time_major) {
      for (int i = 0; i < input_sequence_size; ++i) {
        for (int b = 0; b < num_batches; ++b) {
          const float* golden_start_batch = output[b].data() + i * num_outputs;
          const float* golden_end_batch = golden_start_batch + num_outputs;
          expected.insert(expected.end(), golden_start_batch, golden_end_batch);
        }
      }
    } else {
      for (int b = 0; b < num_batches; ++b) {
        const float* golden_batch_start = output[b].data();
        const float* golden_batch_end =
            golden_batch_start + input_sequence_size * num_outputs;
        expected.insert(expected.end(), golden_batch_start, golden_batch_end);
      }
    }

    EXPECT_THAT(lstm->GetOutput(),
                ElementsAreArray(ArrayFloatNear(expected, tolerance)));
  }
};
|
|
|
|
class NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest
|
|
: public BaseUnidirectionalLstmTest {
|
|
void SetUp() override {
|
|
input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589,
|
|
-0.34550029, 0.04266912, -0.15680569,
|
|
-0.34856534, 0.43890524};
|
|
input_to_cell_weights_ = {-0.50013041, 0.1370284, 0.11810488, 0.2013163,
|
|
-0.20583314, 0.44344562, 0.22077113, -0.29909778};
|
|
input_to_forget_weights_ = {0.09701663, 0.20334584, -0.50592935,
|
|
-0.31343272, -0.40032279, 0.44781327,
|
|
0.01387155, -0.35593212};
|
|
input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829,
|
|
0.40525138, 0.44272184, 0.03897077,
|
|
-0.1556896, 0.19487578};
|
|
input_gate_bias_ = {0., 0., 0., 0.};
|
|
cell_gate_bias_ = {0., 0., 0., 0.};
|
|
forget_gate_bias_ = {1., 1., 1., 1.};
|
|
output_gate_bias_ = {0., 0., 0., 0.};
|
|
|
|
recurrent_to_input_weights_ = {
|
|
-0.0063535, -0.2042388, 0.31454784, -0.35746509,
|
|
0.28902304, 0.08183324, -0.16555229, 0.02286911,
|
|
-0.13566875, 0.03034258, 0.48091322, -0.12528998,
|
|
0.24077177, -0.51332325, -0.33502164, 0.10629296};
|
|
|
|
recurrent_to_cell_weights_ = {
|
|
-0.3407414, 0.24443203, -0.2078532, 0.26320225,
|
|
0.05695659, -0.00123841, -0.4744786, -0.35869038,
|
|
-0.06418842, -0.13502428, -0.501764, 0.22830659,
|
|
-0.46367589, 0.26016325, -0.03894562, -0.16368064};
|
|
|
|
recurrent_to_forget_weights_ = {
|
|
-0.48684245, -0.06655136, 0.42224967, 0.2112639,
|
|
0.27654213, 0.20864892, -0.07646349, 0.45877004,
|
|
0.00141793, -0.14609534, 0.36447752, 0.09196436,
|
|
0.28053468, 0.01560611, -0.20127171, -0.01140004};
|
|
|
|
recurrent_to_output_weights_ = {
|
|
0.43385774, -0.17194885, 0.2718237, 0.09215671,
|
|
0.24107647, -0.39835793, 0.18212086, 0.01301402,
|
|
0.48572797, -0.50656658, 0.20047462, -0.20607421,
|
|
-0.51818722, -0.15390486, 0.0468148, 0.39922136};
|
|
|
|
lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
|
|
lstm_golden_output_ = {{-0.02973187, 0.1229473, 0.20885126, -0.15358765,
|
|
-0.03716109, 0.12507336, 0.41193449, -0.20860538,
|
|
-0.15053082, 0.09120187, 0.24278517, -0.12222792}};
|
|
}
|
|
};
|
|
|
|
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
       LstmBlackBoxTest) {
  constexpr int n_batch = 1;
  constexpr int n_input = 2;
  // Without projection the output width equals the cell width.
  constexpr int n_cell = 4;
  constexpr int n_output = n_cell;
  constexpr int sequence_length = 3;

  // One shape per op input, in the order the model registers its inputs.
  const std::vector<std::vector<int>> input_shapes = {
      {sequence_length, n_batch, n_input},  // input tensor (time major)
      // Input-to-{input, forget, cell, output} weights.
      {n_cell, n_input},
      {n_cell, n_input},
      {n_cell, n_input},
      {n_cell, n_input},
      // Recurrent-to-{input, forget, cell, output} weights.
      {n_cell, n_output},
      {n_cell, n_output},
      {n_cell, n_output},
      {n_cell, n_output},
      // Cell-to-{input, forget, output} peephole weights: unused.
      {0},
      {0},
      {0},
      // {input, forget, cell, output} gate biases.
      {n_cell},
      {n_cell},
      {n_cell},
      {n_cell},
      // Projection weights and bias: unused.
      {0, 0},
      {0},
      // Output state and cell state.
      {n_batch, n_output},
      {n_batch, n_cell},
  };

  UnidirectionalLSTMOpModel lstm(
      n_batch, n_input, n_cell, n_output, sequence_length,
      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/false,
      /*use_projection_weights=*/false, /*use_projection_bias=*/false,
      /*cell_clip=*/0.0, /*proj_clip=*/0.0, input_shapes);

  // Input gate.
  lstm.SetInputToInputWeights(input_to_input_weights_);
  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
  lstm.SetInputGateBias(input_gate_bias_);
  // Forget gate.
  lstm.SetInputToForgetWeights(input_to_forget_weights_);
  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  lstm.SetForgetGateBias(forget_gate_bias_);
  // Cell gate.
  lstm.SetInputToCellWeights(input_to_cell_weights_);
  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  lstm.SetCellBias(cell_gate_bias_);
  // Output gate.
  lstm.SetInputToOutputWeights(input_to_output_weights_);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
  lstm.SetOutputGateBias(output_gate_bias_);

  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
|
|
|
|
// Same model as LstmBlackBoxTest, but with batch-major ([batch][time][...])
// input and output tensors. The op must be built with time_major=false and a
// {n_batch, sequence_length, n_input} input shape; otherwise the batch-major
// code path is never exercised. With n_batch == 1 the two layouts hold the
// same values, so the time-major goldens apply unchanged.
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
       LstmBlackBoxTestBatchMajor) {
  const int n_batch = 1;
  const int n_input = 2;
  // n_cell and n_output have the same size when there is no projection.
  const int n_cell = 4;
  const int n_output = 4;
  const int sequence_length = 3;

  UnidirectionalLSTMOpModel lstm(
      n_batch, n_input, n_cell, n_output, sequence_length,
      /*time_major=*/false, /*use_cifg=*/false, /*use_peephole=*/false,
      /*use_projection_weights=*/false,
      /*use_projection_bias=*/false,
      /*cell_clip=*/0.0, /*proj_clip=*/0.0,
      {
          {n_batch, sequence_length, n_input},  // input tensor (batch major)

          {n_cell, n_input},  // input_to_input_weight tensor
          {n_cell, n_input},  // input_to_forget_weight tensor
          {n_cell, n_input},  // input_to_cell_weight tensor
          {n_cell, n_input},  // input_to_output_weight tensor

          {n_cell, n_output},  // recurrent_to_input_weight tensor
          {n_cell, n_output},  // recurrent_to_forget_weight tensor
          {n_cell, n_output},  // recurrent_to_cell_weight tensor
          {n_cell, n_output},  // recurrent_to_output_weight tensor

          {0},  // cell_to_input_weight tensor
          {0},  // cell_to_forget_weight tensor
          {0},  // cell_to_output_weight tensor

          {n_cell},  // input_gate_bias tensor
          {n_cell},  // forget_gate_bias tensor
          {n_cell},  // cell_gate_bias tensor
          {n_cell},  // output_gate_bias tensor

          {0, 0},  // projection_weight tensor
          {0},     // projection_bias tensor

          {n_batch, n_output},  // output_state tensor
          {n_batch, n_cell},    // cell_state tensor
      });

  lstm.SetInputToInputWeights(input_to_input_weights_);
  lstm.SetInputToCellWeights(input_to_cell_weights_);
  lstm.SetInputToForgetWeights(input_to_forget_weights_);
  lstm.SetInputToOutputWeights(input_to_output_weights_);

  lstm.SetInputGateBias(input_gate_bias_);
  lstm.SetCellBias(cell_gate_bias_);
  lstm.SetForgetGateBias(forget_gate_bias_);
  lstm.SetOutputGateBias(output_gate_bias_);

  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);

  // VerifyGoldens lays the per-batch vectors out in batch-major order itself,
  // so no manual reshuffling is needed here.
  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/1e-5,
                /*time_major=*/false);
}
|
|
|
|
// Hybrid variant of LstmBlackBoxTest: weights are quantized to UINT8, so the
// output is compared with a looser tolerance.
TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
       HybridLstmBlackBoxTestUint8) {
  constexpr int n_batch = 1;
  constexpr int n_input = 2;
  // Without projection the output width equals the cell width.
  constexpr int n_cell = 4;
  constexpr int n_output = n_cell;
  constexpr int sequence_length = 3;

  // One shape per op input, in the order the model registers its inputs.
  const std::vector<std::vector<int>> input_shapes = {
      {sequence_length, n_batch, n_input},  // input tensor (time major)
      // Input-to-{input, forget, cell, output} weights.
      {n_cell, n_input},
      {n_cell, n_input},
      {n_cell, n_input},
      {n_cell, n_input},
      // Recurrent-to-{input, forget, cell, output} weights.
      {n_cell, n_output},
      {n_cell, n_output},
      {n_cell, n_output},
      {n_cell, n_output},
      // Cell-to-{input, forget, output} peephole weights: unused.
      {0},
      {0},
      {0},
      // {input, forget, cell, output} gate biases.
      {n_cell},
      {n_cell},
      {n_cell},
      {n_cell},
      // Projection weights and bias: unused.
      {0, 0},
      {0},
      // Output state and cell state.
      {n_batch, n_output},
      {n_batch, n_cell},
  };

  HybridUnidirectionalLSTMOpModel lstm(
      n_batch, n_input, n_cell, n_output, sequence_length,
      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/false,
      /*use_projection_weights=*/false, /*use_projection_bias=*/false,
      /*cell_clip=*/0.0, /*proj_clip=*/0.0, input_shapes,
      /*tensor_type=*/TensorType_UINT8,
      /*asymmetric_quantize_inputs=*/GetParam());

  // Input gate.
  lstm.SetInputToInputWeights(input_to_input_weights_);
  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
  lstm.SetInputGateBias(input_gate_bias_);
  // Forget gate.
  lstm.SetInputToForgetWeights(input_to_forget_weights_);
  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  lstm.SetForgetGateBias(forget_gate_bias_);
  // Cell gate.
  lstm.SetInputToCellWeights(input_to_cell_weights_);
  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  lstm.SetCellBias(cell_gate_bias_);
  // Output gate.
  lstm.SetInputToOutputWeights(input_to_output_weights_);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
  lstm.SetOutputGateBias(output_gate_bias_);

  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
                /*tolerance=*/0.0157651);
}
|
|
|
|
// Hybrid variant of LstmBlackBoxTest: weights are quantized to INT8, so the
// output is compared with a looser tolerance.
TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
       HybridLstmBlackBoxTestInt8) {
  constexpr int n_batch = 1;
  constexpr int n_input = 2;
  // Without projection the output width equals the cell width.
  constexpr int n_cell = 4;
  constexpr int n_output = n_cell;
  constexpr int sequence_length = 3;

  // One shape per op input, in the order the model registers its inputs.
  const std::vector<std::vector<int>> input_shapes = {
      {sequence_length, n_batch, n_input},  // input tensor (time major)
      // Input-to-{input, forget, cell, output} weights.
      {n_cell, n_input},
      {n_cell, n_input},
      {n_cell, n_input},
      {n_cell, n_input},
      // Recurrent-to-{input, forget, cell, output} weights.
      {n_cell, n_output},
      {n_cell, n_output},
      {n_cell, n_output},
      {n_cell, n_output},
      // Cell-to-{input, forget, output} peephole weights: unused.
      {0},
      {0},
      {0},
      // {input, forget, cell, output} gate biases.
      {n_cell},
      {n_cell},
      {n_cell},
      {n_cell},
      // Projection weights and bias: unused.
      {0, 0},
      {0},
      // Output state and cell state.
      {n_batch, n_output},
      {n_batch, n_cell},
  };

  HybridUnidirectionalLSTMOpModel lstm(
      n_batch, n_input, n_cell, n_output, sequence_length,
      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/false,
      /*use_projection_weights=*/false, /*use_projection_bias=*/false,
      /*cell_clip=*/0.0, /*proj_clip=*/0.0, input_shapes,
      /*tensor_type=*/TensorType_INT8,
      /*asymmetric_quantize_inputs=*/GetParam());

  // Input gate.
  lstm.SetInputToInputWeights(input_to_input_weights_);
  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
  lstm.SetInputGateBias(input_gate_bias_);
  // Forget gate.
  lstm.SetInputToForgetWeights(input_to_forget_weights_);
  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  lstm.SetForgetGateBias(forget_gate_bias_);
  // Cell gate.
  lstm.SetInputToCellWeights(input_to_cell_weights_);
  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  lstm.SetCellBias(cell_gate_bias_);
  // Output gate.
  lstm.SetInputToOutputWeights(input_to_output_weights_);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
  lstm.SetOutputGateBias(output_gate_bias_);

  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
                /*tolerance=*/0.0157651);
}
|
|
|
|
class CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest
|
|
: public BaseUnidirectionalLstmTest {
|
|
void SetUp() override {
|
|
input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726,
|
|
0.05100781, 0.04717243, 0.48944736,
|
|
-0.38535351, -0.17212132};
|
|
|
|
input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988,
|
|
-0.3633365, -0.22755712, 0.28253698,
|
|
0.24407166, 0.33826375};
|
|
|
|
input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593,
|
|
-0.09426838, -0.44257352, 0.54939759,
|
|
0.01533556, 0.42751634};
|
|
cell_gate_bias_ = {0., 0., 0., 0.};
|
|
forget_gate_bias_ = {1., 1., 1., 1.};
|
|
output_gate_bias_ = {0., 0., 0., 0.};
|
|
|
|
recurrent_to_cell_weights_ = {
|
|
0.54066205, -0.32668582, -0.43562764, -0.56094903,
|
|
0.42957711, 0.01841056, -0.32764608, -0.33027974,
|
|
-0.10826075, 0.20675004, 0.19069612, -0.03026325,
|
|
-0.54532051, 0.33003211, 0.44901288, 0.21193194};
|
|
|
|
recurrent_to_forget_weights_ = {
|
|
-0.13832897, -0.0515101, -0.2359007, -0.16661474,
|
|
-0.14340827, 0.36986142, 0.23414481, 0.55899,
|
|
0.10798943, -0.41174671, 0.17751795, -0.34484994,
|
|
-0.35874045, -0.11352962, 0.27268326, 0.54058349};
|
|
|
|
recurrent_to_output_weights_ = {
|
|
0.41613156, 0.42610586, -0.16495961, -0.5663873,
|
|
0.30579174, -0.05115908, -0.33941799, 0.23364776,
|
|
0.11178309, 0.09481031, -0.26424935, 0.46261835,
|
|
0.50248802, 0.26114327, -0.43736315, 0.33149987};
|
|
|
|
cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408,
|
|
0.31544167};
|
|
cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703,
|
|
-0.77109635};
|
|
|
|
lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
|
|
lstm_golden_output_ = {{-0.36444446, -0.00352185, 0.12886585, -0.05163646,
|
|
-0.42312205, -0.01218222, 0.24201041, -0.08124574,
|
|
-0.358325, -0.04621704, 0.21641694, -0.06471302}};
|
|
}
|
|
};
|
|
|
|
TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
       LstmBlackBoxTest) {
  constexpr int n_batch = 1;
  constexpr int n_input = 2;
  // Without projection the output width equals the cell width.
  constexpr int n_cell = 4;
  constexpr int n_output = n_cell;
  constexpr int sequence_length = 3;

  // One shape per op input, in the order the model registers its inputs.
  // CIFG removes every input-gate tensor, so those shapes are empty.
  const std::vector<std::vector<int>> input_shapes = {
      {sequence_length, n_batch, n_input},  // input tensor (time major)
      // Input-to-{input, forget, cell, output} weights.
      {0, 0},
      {n_cell, n_input},
      {n_cell, n_input},
      {n_cell, n_input},
      // Recurrent-to-{input, forget, cell, output} weights.
      {0, 0},
      {n_cell, n_output},
      {n_cell, n_output},
      {n_cell, n_output},
      // Cell-to-{input, forget, output} peephole weights.
      {0},
      {n_cell},
      {n_cell},
      // {input, forget, cell, output} gate biases.
      {0},
      {n_cell},
      {n_cell},
      {n_cell},
      // Projection weights and bias: unused.
      {0, 0},
      {0},
      // Output state and cell state.
      {n_batch, n_output},
      {n_batch, n_cell},
  };

  UnidirectionalLSTMOpModel lstm(
      n_batch, n_input, n_cell, n_output, sequence_length,
      /*time_major=*/true, /*use_cifg=*/true, /*use_peephole=*/true,
      /*use_projection_weights=*/false, /*use_projection_bias=*/false,
      /*cell_clip=*/0.0, /*proj_clip=*/0.0, input_shapes);

  // Forget gate.
  lstm.SetInputToForgetWeights(input_to_forget_weights_);
  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  lstm.SetCellToForgetWeights(cell_to_forget_weights_);
  lstm.SetForgetGateBias(forget_gate_bias_);
  // Cell gate.
  lstm.SetInputToCellWeights(input_to_cell_weights_);
  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  lstm.SetCellBias(cell_gate_bias_);
  // Output gate.
  lstm.SetInputToOutputWeights(input_to_output_weights_);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
  lstm.SetCellToOutputWeights(cell_to_output_weights_);
  lstm.SetOutputGateBias(output_gate_bias_);

  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
|
|
|
|
// Hybrid CIFG + peephole black-box test with UINT8 weights, no projection
// layer and no cell/projection clipping (both clip values are 0.0).
// Because use_cifg is set, every input-gate tensor is declared with a
// zero-sized shape and is never populated below. All remaining weights and
// biases come from the fixture. The model's output for lstm_input_ is checked
// against lstm_golden_output_ with an explicit tolerance.
TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
       HybridLstmBlackBoxTestUint8) {
  constexpr int kBatches = 1;
  constexpr int kInputs = 2;
  constexpr int kCells = 4;
  // With no projection layer the output width is the cell count.
  constexpr int kOutputs = kCells;
  constexpr int kSteps = 3;
  // Presumably sized to absorb 8-bit weight quantization error.
  constexpr double kTolerance = 0.03573;

  // NOTE(review): GetParam() is forwarded as the final constructor argument;
  // presumably this is asymmetric_quantize_inputs — confirm against the
  // HybridUnidirectionalLSTMOpModel signature.
  HybridUnidirectionalLSTMOpModel model(
      kBatches, kInputs, kCells, kOutputs, kSteps,
      /*time_major=*/true, /*use_cifg=*/true, /*use_peephole=*/true,
      /*use_projection_weights=*/false,
      /*use_projection_bias=*/false,
      /*cell_clip=*/0.0, /*proj_clip=*/0.0,
      {
          {kSteps, kBatches, kInputs},  // input tensor

          {0, 0},             // input_to_input_weight tensor (unused: CIFG)
          {kCells, kInputs},  // input_to_forget_weight tensor
          {kCells, kInputs},  // input_to_cell_weight tensor
          {kCells, kInputs},  // input_to_output_weight tensor

          {0, 0},              // recurrent_to_input_weight tensor (unused)
          {kCells, kOutputs},  // recurrent_to_forget_weight tensor
          {kCells, kOutputs},  // recurrent_to_cell_weight tensor
          {kCells, kOutputs},  // recurrent_to_output_weight tensor

          {0},       // cell_to_input_weight tensor (unused: CIFG)
          {kCells},  // cell_to_forget_weight tensor
          {kCells},  // cell_to_output_weight tensor

          {0},       // input_gate_bias tensor (unused: CIFG)
          {kCells},  // forget_gate_bias tensor
          {kCells},  // cell_gate_bias tensor
          {kCells},  // output_gate_bias tensor

          {0, 0},  // projection_weight tensor (no projection)
          {0},     // projection_bias tensor (no projection)

          {kBatches, kOutputs},  // output_state tensor
          {kBatches, kCells},    // cell_state tensor
      },
      TensorType_UINT8, GetParam());

  // Input-to-gate weights (the input gate is omitted under CIFG).
  model.SetInputToCellWeights(input_to_cell_weights_);
  model.SetInputToForgetWeights(input_to_forget_weights_);
  model.SetInputToOutputWeights(input_to_output_weights_);

  // Gate biases.
  model.SetCellBias(cell_gate_bias_);
  model.SetForgetGateBias(forget_gate_bias_);
  model.SetOutputGateBias(output_gate_bias_);

  // Recurrent (previous output) weights.
  model.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  model.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  model.SetRecurrentToOutputWeights(recurrent_to_output_weights_);

  // Peephole weights.
  model.SetCellToForgetWeights(cell_to_forget_weights_);
  model.SetCellToOutputWeights(cell_to_output_weights_);

  VerifyGoldens(lstm_input_, lstm_golden_output_, &model,
                /*tolerance=*/kTolerance);
}
|
|
|
|
// Hybrid CIFG + peephole black-box test with INT8 weights, no projection
// layer and no cell/projection clipping (both clip values are 0.0).
// Because use_cifg is set, every input-gate tensor is declared with a
// zero-sized shape and is never populated below. All remaining weights and
// biases come from the fixture. The model's output for lstm_input_ is checked
// against lstm_golden_output_ with an explicit tolerance.
TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
       HybridLstmBlackBoxTestInt8) {
  constexpr int kBatches = 1;
  constexpr int kInputs = 2;
  constexpr int kCells = 4;
  // With no projection layer the output width is the cell count.
  constexpr int kOutputs = kCells;
  constexpr int kSteps = 3;
  // Presumably sized to absorb 8-bit weight quantization error.
  constexpr double kTolerance = 0.03573;

  // NOTE(review): GetParam() is forwarded as the final constructor argument;
  // presumably this is asymmetric_quantize_inputs — confirm against the
  // HybridUnidirectionalLSTMOpModel signature.
  HybridUnidirectionalLSTMOpModel model(
      kBatches, kInputs, kCells, kOutputs, kSteps,
      /*time_major=*/true, /*use_cifg=*/true, /*use_peephole=*/true,
      /*use_projection_weights=*/false,
      /*use_projection_bias=*/false,
      /*cell_clip=*/0.0, /*proj_clip=*/0.0,
      {
          {kSteps, kBatches, kInputs},  // input tensor

          {0, 0},             // input_to_input_weight tensor (unused: CIFG)
          {kCells, kInputs},  // input_to_forget_weight tensor
          {kCells, kInputs},  // input_to_cell_weight tensor
          {kCells, kInputs},  // input_to_output_weight tensor

          {0, 0},              // recurrent_to_input_weight tensor (unused)
          {kCells, kOutputs},  // recurrent_to_forget_weight tensor
          {kCells, kOutputs},  // recurrent_to_cell_weight tensor
          {kCells, kOutputs},  // recurrent_to_output_weight tensor

          {0},       // cell_to_input_weight tensor (unused: CIFG)
          {kCells},  // cell_to_forget_weight tensor
          {kCells},  // cell_to_output_weight tensor

          {0},       // input_gate_bias tensor (unused: CIFG)
          {kCells},  // forget_gate_bias tensor
          {kCells},  // cell_gate_bias tensor
          {kCells},  // output_gate_bias tensor

          {0, 0},  // projection_weight tensor (no projection)
          {0},     // projection_bias tensor (no projection)

          {kBatches, kOutputs},  // output_state tensor
          {kBatches, kCells},    // cell_state tensor
      },
      TensorType_INT8, GetParam());

  // Input-to-gate weights (the input gate is omitted under CIFG).
  model.SetInputToCellWeights(input_to_cell_weights_);
  model.SetInputToForgetWeights(input_to_forget_weights_);
  model.SetInputToOutputWeights(input_to_output_weights_);

  // Gate biases.
  model.SetCellBias(cell_gate_bias_);
  model.SetForgetGateBias(forget_gate_bias_);
  model.SetOutputGateBias(output_gate_bias_);

  // Recurrent (previous output) weights.
  model.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  model.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  model.SetRecurrentToOutputWeights(recurrent_to_output_weights_);

  // Peephole weights.
  model.SetCellToForgetWeights(cell_to_forget_weights_);
  model.SetCellToOutputWeights(cell_to_output_weights_);

  VerifyGoldens(lstm_input_, lstm_golden_output_, &model,
                /*tolerance=*/kTolerance);
}
|
|
|
|
class NoCifgPeepholeProjectionClippingUnidirectionalLstmTest
|
|
: public BaseUnidirectionalLstmTest {
|
|
void SetUp() override {
|
|
input_to_input_weights_ = {
|
|
0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
|
|
0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048,
|
|
-0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385,
|
|
-0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282,
|
|
-0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627,
|
|
-0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226,
|
|
-0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059,
|
|
0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698,
|
|
0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206,
|
|
0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585,
|
|
-0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063,
|
|
0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603,
|
|
-0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682,
|
|
-0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988,
|
|
-0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764,
|
|
0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476,
|
|
-0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012,
|
|
-0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604,
|
|
-0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654,
|
|
-0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677};
|
|
|
|
input_to_forget_weights_ = {
|
|
-0.0018401089, -0.004852237, 0.03698424, 0.014181704,
|
|
0.028273236, -0.016726194, -0.05249759, -0.10204261,
|
|
0.00861066, -0.040979505, -0.009899187, 0.01923892,
|
|
-0.028177269, -0.08535103, -0.14585495, 0.10662567,
|
|
-0.01909731, -0.017883534, -0.0047269356, -0.045103323,
|
|
0.0030784295, 0.076784775, 0.07463696, 0.094531395,
|
|
0.0814421, -0.12257899, -0.033945758, -0.031303465,
|
|
0.045630626, 0.06843887, -0.13492945, -0.012480007,
|
|
-0.0811829, -0.07224499, -0.09628791, 0.045100946,
|
|
0.0012300825, 0.013964662, 0.099372394, 0.02543059,
|
|
0.06958324, 0.034257296, 0.0482646, 0.06267997,
|
|
0.052625068, 0.12784666, 0.07077897, 0.025725935,
|
|
0.04165009, 0.07241905, 0.018668644, -0.037377294,
|
|
-0.06277783, -0.08833636, -0.040120605, -0.011405586,
|
|
-0.007808335, -0.010301386, -0.005102167, 0.027717464,
|
|
0.05483423, 0.11449111, 0.11289652, 0.10939839,
|
|
0.13396506, -0.08402166, -0.01901462, -0.044678304,
|
|
-0.07720565, 0.014350063, -0.11757958, -0.0652038,
|
|
-0.08185733, -0.076754324, -0.092614375, 0.10405491,
|
|
0.052960336, 0.035755895, 0.035839386, -0.012540553,
|
|
0.036881298, 0.02913376, 0.03420159, 0.05448447,
|
|
-0.054523353, 0.02582715, 0.02327355, -0.011857179,
|
|
-0.0011980024, -0.034641717, -0.026125094, -0.17582615,
|
|
-0.15923657, -0.27486774, -0.0006143371, 0.0001771948,
|
|
-8.470171e-05, 0.02651807, 0.045790765, 0.06956496};
|
|
|
|
input_to_cell_weights_ = {
|
|
-0.04580283, -0.09549462, -0.032418985, -0.06454633,
|
|
-0.043528453, 0.043018587, -0.049152344, -0.12418144,
|
|
-0.078985475, -0.07596889, 0.019484362, -0.11434962,
|
|
-0.0074034138, -0.06314844, -0.092981495, 0.0062155537,
|
|
-0.025034338, -0.0028890965, 0.048929527, 0.06235075,
|
|
0.10665918, -0.032036792, -0.08505916, -0.10843358,
|
|
-0.13002433, -0.036816437, -0.02130134, -0.016518239,
|
|
0.0047691227, -0.0025825808, 0.066017866, 0.029991534,
|
|
-0.10652836, -0.1037554, -0.13056071, -0.03266643,
|
|
-0.033702414, -0.006473424, -0.04611692, 0.014419339,
|
|
-0.025174323, 0.0396852, 0.081777506, 0.06157468,
|
|
0.10210095, -0.009658194, 0.046511717, 0.03603906,
|
|
0.0069369148, 0.015960095, -0.06507666, 0.09551598,
|
|
0.053568836, 0.06408714, 0.12835667, -0.008714329,
|
|
-0.20211966, -0.12093674, 0.029450472, 0.2849013,
|
|
-0.029227901, 0.1164364, -0.08560263, 0.09941786,
|
|
-0.036999565, -0.028842626, -0.0033637602, -0.017012902,
|
|
-0.09720865, -0.11193351, -0.029155117, -0.017936034,
|
|
-0.009768936, -0.04223324, -0.036159635, 0.06505112,
|
|
-0.021742892, -0.023377212, -0.07221364, -0.06430552,
|
|
0.05453865, 0.091149814, 0.06387331, 0.007518393,
|
|
0.055960953, 0.069779344, 0.046411168, 0.10509911,
|
|
0.07463894, 0.0075130584, 0.012850982, 0.04555431,
|
|
0.056955688, 0.06555285, 0.050801456, -0.009862683,
|
|
0.00826772, -0.026555609, -0.0073611983, -0.0014897042};
|
|
|
|
input_to_output_weights_ = {
|
|
-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918,
|
|
-0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534,
|
|
0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722,
|
|
-0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761,
|
|
-0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394,
|
|
0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154,
|
|
-0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135,
|
|
-0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564,
|
|
-0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047,
|
|
-0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304,
|
|
0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946,
|
|
0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646,
|
|
0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813,
|
|
-0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403,
|
|
0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415,
|
|
0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495,
|
|
-0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158,
|
|
0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295,
|
|
-0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739,
|
|
-0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956};
|
|
|
|
input_gate_bias_ = {0.02234832, 0.14757581, 0.18176508, 0.10380666,
|
|
0.053110216, -0.06928846, -0.13942584, -0.11816189,
|
|
0.19483899, 0.03652339, -0.10250295, 0.036714908,
|
|
-0.18426876, 0.036065217, 0.21810818, 0.02383196,
|
|
-0.043370757, 0.08690144, -0.04444982, 0.00030581196};
|
|
|
|
forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696,
|
|
0.11098921, 0.15378423, 0.09263801, 0.09790885,
|
|
0.09508917, 0.061199076, 0.07665568, -0.015443159,
|
|
-0.03499149, 0.046190713, 0.08895977, 0.10899629,
|
|
0.40694186, 0.06030037, 0.012413437, -0.06108739};
|
|
|
|
cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132, 0.033463873,
|
|
-0.1483596, -0.10639995, -0.091433935, 0.058573797,
|
|
-0.06809782, -0.07889636, -0.043246906, -0.09829136,
|
|
-0.4279842, 0.034901652, 0.18797937, 0.0075234566,
|
|
0.016178843, 0.1749513, 0.13975595, 0.92058027};
|
|
|
|
output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469, 0.12648113,
|
|
0.027195795, 0.35373217, -0.018957434, 0.008907322,
|
|
-0.0762701, 0.12018895, 0.04216877, 0.0022856654,
|
|
0.040952638, 0.3147856, 0.08225149, -0.057416286,
|
|
-0.14995944, -0.008040261, 0.13208859, 0.029760877};
|
|
|
|
recurrent_to_input_weights_ = {
|
|
-0.001374326, -0.078856036, 0.10672688, 0.029162422,
|
|
-0.11585556, 0.02557986, -0.13446963, -0.035785314,
|
|
-0.01244275, 0.025961924, -0.02337298, -0.044228926,
|
|
-0.055839065, -0.046598054, -0.010546039, -0.06900766,
|
|
0.027239809, 0.022582639, -0.013296484, -0.05459212,
|
|
0.08981, -0.045407712, 0.08682226, -0.06867011,
|
|
-0.14390695, -0.02916037, 0.000996957, 0.091420636,
|
|
0.14283475, -0.07390571, -0.06402044, 0.062524505,
|
|
-0.093129106, 0.04860203, -0.08364217, -0.08119002,
|
|
0.009352075, 0.22920375, 0.0016303885, 0.11583097,
|
|
-0.13732095, 0.012405723, -0.07551853, 0.06343048,
|
|
0.12162708, -0.031923793, -0.014335606, 0.01790974,
|
|
-0.10650317, -0.0724401, 0.08554849, -0.05727212,
|
|
0.06556731, -0.042729504, -0.043227166, 0.011683251,
|
|
-0.013082158, -0.029302018, -0.010899579, -0.062036745,
|
|
-0.022509435, -0.00964907, -0.01567329, 0.04260106,
|
|
-0.07787477, -0.11576462, 0.017356863, 0.048673786,
|
|
-0.017577527, -0.05527947, -0.082487635, -0.040137455,
|
|
-0.10820036, -0.04666372, 0.022746278, -0.07851417,
|
|
0.01068115, 0.032956902, 0.022433773, 0.0026891115,
|
|
0.08944216, -0.0685835, 0.010513544, 0.07228705,
|
|
0.02032331, -0.059686817, -0.0005566496, -0.086984694,
|
|
0.040414046, -0.1380399, 0.094208956, -0.05722982,
|
|
0.012092817, -0.04989123, -0.086576, -0.003399834,
|
|
-0.04696032, -0.045747425, 0.10091314, 0.048676282,
|
|
-0.029037097, 0.031399418, -0.0040285117, 0.047237843,
|
|
0.09504992, 0.041799378, -0.049185462, -0.031518843,
|
|
-0.10516937, 0.026374253, 0.10058866, -0.0033195973,
|
|
-0.041975245, 0.0073591834, 0.0033782164, -0.004325073,
|
|
-0.10167381, 0.042500053, -0.01447153, 0.06464186,
|
|
-0.017142897, 0.03312627, 0.009205989, 0.024138335,
|
|
-0.011337001, 0.035530265, -0.010912711, 0.0706555,
|
|
-0.005894094, 0.051841937, -0.1401738, -0.02351249,
|
|
0.0365468, 0.07590991, 0.08838724, 0.021681072,
|
|
-0.10086113, 0.019608743, -0.06195883, 0.077335775,
|
|
0.023646897, -0.095322326, 0.02233014, 0.09756986,
|
|
-0.048691444, -0.009579111, 0.07595467, 0.11480546,
|
|
-0.09801813, 0.019894179, 0.08502348, 0.004032281,
|
|
0.037211012, 0.068537936, -0.048005626, -0.091520436,
|
|
-0.028379958, -0.01556313, 0.06554592, -0.045599163,
|
|
-0.01672207, -0.020169014, -0.011877351, -0.20212261,
|
|
0.010889619, 0.0047078193, 0.038385306, 0.08540671,
|
|
-0.017140968, -0.0035865551, 0.016678626, 0.005633034,
|
|
0.015963363, 0.00871737, 0.060130805, 0.028611384,
|
|
0.10109069, -0.015060172, -0.07894427, 0.06401885,
|
|
0.011584063, -0.024466386, 0.0047652307, -0.09041358,
|
|
0.030737216, -0.0046374933, 0.14215417, -0.11823516,
|
|
0.019899689, 0.006106124, -0.027092824, 0.0786356,
|
|
0.05052217, -0.058925, -0.011402121, -0.024987547,
|
|
-0.0013661642, -0.06832946, -0.015667673, -0.1083353,
|
|
-0.00096863037, -0.06988685, -0.053350925, -0.027275559,
|
|
-0.033664223, -0.07978348, -0.025200296, -0.017207067,
|
|
-0.058403496, -0.055697463, 0.005798788, 0.12965427,
|
|
-0.062582195, 0.0013350133, -0.10482091, 0.0379771,
|
|
0.072521195, -0.0029455067, -0.13797039, -0.03628521,
|
|
0.013806405, -0.017858358, -0.01008298, -0.07700066,
|
|
-0.017081132, 0.019358726, 0.0027079724, 0.004635139,
|
|
0.062634714, -0.02338735, -0.039547626, -0.02050681,
|
|
0.03385117, -0.083611414, 0.002862572, -0.09421313,
|
|
0.058618143, -0.08598433, 0.00972939, 0.023867095,
|
|
-0.053934585, -0.023203006, 0.07452513, -0.048767887,
|
|
-0.07314807, -0.056307215, -0.10433547, -0.06440842,
|
|
0.04328182, 0.04389765, -0.020006588, -0.09076438,
|
|
-0.11652589, -0.021705797, 0.03345259, -0.010329105,
|
|
-0.025767034, 0.013057034, -0.07316461, -0.10145612,
|
|
0.06358255, 0.18531723, 0.07759293, 0.12006465,
|
|
0.1305557, 0.058638252, -0.03393652, 0.09622831,
|
|
-0.16253184, -2.4580743e-06, 0.079869635, -0.070196845,
|
|
-0.005644518, 0.06857898, -0.12598175, -0.035084512,
|
|
0.03156317, -0.12794146, -0.031963028, 0.04692781,
|
|
0.030070418, 0.0071660685, -0.095516115, -0.004643372,
|
|
0.040170413, -0.062104587, -0.0037324072, 0.0554317,
|
|
0.08184801, -0.019164372, 0.06791302, 0.034257166,
|
|
-0.10307039, 0.021943003, 0.046745934, 0.0790918,
|
|
-0.0265588, -0.007824208, 0.042546265, -0.00977924,
|
|
-0.0002440307, -0.017384544, -0.017990116, 0.12252321,
|
|
-0.014512694, -0.08251313, 0.08861942, 0.13589665,
|
|
0.026351685, 0.012641483, 0.07466548, 0.044301085,
|
|
-0.045414884, -0.051112458, 0.03444247, -0.08502782,
|
|
-0.04106223, -0.028126027, 0.028473156, 0.10467447};
|
|
|
|
recurrent_to_cell_weights_ = {
|
|
-0.037322544, 0.018592842, 0.0056175636, -0.06253426,
|
|
0.055647098, -0.05713207, -0.05626563, 0.005559383,
|
|
0.03375411, -0.025757805, -0.088049285, 0.06017052,
|
|
-0.06570978, 0.007384076, 0.035123326, -0.07920549,
|
|
0.053676967, 0.044480428, -0.07663568, 0.0071805613,
|
|
0.08089997, 0.05143358, 0.038261272, 0.03339287,
|
|
-0.027673481, 0.044746667, 0.028349208, 0.020090483,
|
|
-0.019443132, -0.030755889, -0.0040000007, 0.04465846,
|
|
-0.021585021, 0.0031670958, 0.0053199246, -0.056117613,
|
|
-0.10893326, 0.076739706, -0.08509834, -0.027997585,
|
|
0.037871376, 0.01449768, -0.09002357, -0.06111149,
|
|
-0.046195522, 0.0422062, -0.005683705, -0.1253618,
|
|
-0.012925729, -0.04890792, 0.06985068, 0.037654128,
|
|
0.03398274, -0.004781977, 0.007032333, -0.031787455,
|
|
0.010868644, -0.031489216, 0.09525667, 0.013939797,
|
|
0.0058680447, 0.0167067, 0.02668468, -0.04797466,
|
|
-0.048885044, -0.12722108, 0.035304096, 0.06554885,
|
|
0.00972396, -0.039238118, -0.05159735, -0.11329045,
|
|
0.1613692, -0.03750952, 0.06529313, -0.071974665,
|
|
-0.11769596, 0.015524369, -0.0013754242, -0.12446318,
|
|
0.02786344, -0.014179351, 0.005264273, 0.14376344,
|
|
0.015983658, 0.03406988, -0.06939408, 0.040699873,
|
|
0.02111075, 0.09669095, 0.041345075, -0.08316494,
|
|
-0.07684199, -0.045768797, 0.032298047, -0.041805092,
|
|
0.0119405, 0.0061010392, 0.12652606, 0.0064572375,
|
|
-0.024950314, 0.11574242, 0.04508852, -0.04335324,
|
|
0.06760663, -0.027437469, 0.07216407, 0.06977076,
|
|
-0.05438599, 0.034033038, -0.028602652, 0.05346137,
|
|
0.043184172, -0.037189785, 0.10420091, 0.00882477,
|
|
-0.054019816, -0.074273005, -0.030617684, -0.0028467078,
|
|
0.024302477, -0.0038869337, 0.005332455, 0.0013399826,
|
|
0.04361412, -0.007001822, 0.09631092, -0.06702025,
|
|
-0.042049985, -0.035070654, -0.04103342, -0.10273396,
|
|
0.0544271, 0.037184782, -0.13150354, -0.0058036847,
|
|
-0.008264958, 0.042035464, 0.05891794, 0.029673764,
|
|
0.0063542654, 0.044788733, 0.054816857, 0.062257513,
|
|
-0.00093483756, 0.048938446, -0.004952862, -0.007730018,
|
|
-0.04043371, -0.017094059, 0.07229206, -0.023670016,
|
|
-0.052195564, -0.025616996, -0.01520939, 0.045104615,
|
|
-0.007376126, 0.003533447, 0.006570588, 0.056037236,
|
|
0.12436656, 0.051817212, 0.028532185, -0.08686856,
|
|
0.11868599, 0.07663395, -0.07323171, 0.03463402,
|
|
-0.050708205, -0.04458982, -0.11590894, 0.021273347,
|
|
0.1251325, -0.15313013, -0.12224372, 0.17228661,
|
|
0.023029093, 0.086124025, 0.006445803, -0.03496501,
|
|
0.028332196, 0.04449512, -0.042436164, -0.026587414,
|
|
-0.006041347, -0.09292539, -0.05678812, 0.03897832,
|
|
0.09465633, 0.008115513, -0.02171956, 0.08304309,
|
|
0.071401566, 0.019622514, 0.032163795, -0.004167056,
|
|
0.02295182, 0.030739572, 0.056506045, 0.004612461,
|
|
0.06524936, 0.059999723, 0.046395954, -0.0045512207,
|
|
-0.1335546, -0.030136576, 0.11584653, -0.014678886,
|
|
0.0020118146, -0.09688814, -0.0790206, 0.039770417,
|
|
-0.0329582, 0.07922767, 0.029322514, 0.026405897,
|
|
0.04207835, -0.07073373, 0.063781224, 0.0859677,
|
|
-0.10925287, -0.07011058, 0.048005477, 0.03438226,
|
|
-0.09606514, -0.006669445, -0.043381985, 0.04240257,
|
|
-0.06955775, -0.06769346, 0.043903265, -0.026784198,
|
|
-0.017840602, 0.024307009, -0.040079936, -0.019946516,
|
|
0.045318738, -0.12233574, 0.026170589, 0.0074471775,
|
|
0.15978073, 0.10185836, 0.10298046, -0.015476589,
|
|
-0.039390966, -0.072174534, 0.0739445, -0.1211869,
|
|
-0.0347889, -0.07943156, 0.014809798, -0.12412325,
|
|
-0.0030663363, 0.039695457, 0.0647603, -0.08291318,
|
|
-0.018529687, -0.004423833, 0.0037507233, 0.084633216,
|
|
-0.01514876, -0.056505352, -0.012800942, -0.06994386,
|
|
0.012962922, -0.031234352, 0.07029052, 0.016418684,
|
|
0.03618972, 0.055686004, -0.08663945, -0.017404709,
|
|
-0.054761406, 0.029065743, 0.052404847, 0.020238016,
|
|
0.0048197987, -0.0214882, 0.07078733, 0.013016777,
|
|
0.06262858, 0.009184685, 0.020785125, -0.043904778,
|
|
-0.0270329, -0.03299152, -0.060088247, -0.015162964,
|
|
-0.001828936, 0.12642565, -0.056757294, 0.013586685,
|
|
0.09232601, -0.035886683, 0.06000002, 0.05229691,
|
|
-0.052580316, -0.082029596, -0.010794592, 0.012947712,
|
|
-0.036429964, -0.085508935, -0.13127148, -0.017744139,
|
|
0.031502828, 0.036232427, -0.031581745, 0.023051167,
|
|
-0.05325106, -0.03421577, 0.028793324, -0.034633752,
|
|
-0.009881397, -0.043551125, -0.018609839, 0.0019097115,
|
|
-0.008799762, 0.056595087, 0.0022273948, 0.055752404};
|
|
|
|
recurrent_to_forget_weights_ = {
|
|
-0.057784554, -0.026057621, -0.068447545, -0.022581743,
|
|
0.14811787, 0.10826372, 0.09471067, 0.03987225,
|
|
-0.0039523416, 0.00030638507, 0.053185795, 0.10572994,
|
|
0.08414449, -0.022036452, -0.00066928595, -0.09203576,
|
|
0.032950465, -0.10985798, -0.023809856, 0.0021431844,
|
|
-0.02196096, -0.00326074, 0.00058621005, -0.074678116,
|
|
-0.06193199, 0.055729095, 0.03736828, 0.020123724,
|
|
0.061878487, -0.04729229, 0.034919553, -0.07585433,
|
|
-0.04421272, -0.044019096, 0.085488975, 0.04058006,
|
|
-0.06890133, -0.030951202, -0.024628663, -0.07672815,
|
|
0.034293607, 0.08556707, -0.05293577, -0.033561368,
|
|
-0.04899627, 0.0241671, 0.015736353, -0.095442444,
|
|
-0.029564252, 0.016493602, -0.035026584, 0.022337519,
|
|
-0.026871363, 0.004780428, 0.0077918363, -0.03601621,
|
|
0.016435321, -0.03263031, -0.09543275, -0.047392778,
|
|
0.013454138, 0.028934088, 0.01685226, -0.086110644,
|
|
-0.046250615, -0.01847454, 0.047608484, 0.07339695,
|
|
0.034546845, -0.04881143, 0.009128804, -0.08802852,
|
|
0.03761666, 0.008096139, -0.014454086, 0.014361001,
|
|
-0.023502491, -0.0011840804, -0.07607001, 0.001856849,
|
|
-0.06509276, -0.006021153, -0.08570962, -0.1451793,
|
|
0.060212336, 0.055259194, 0.06974018, 0.049454916,
|
|
-0.027794661, -0.08077226, -0.016179763, 0.1169753,
|
|
0.17213494, -0.0056326236, -0.053934924, -0.0124349,
|
|
-0.11520337, 0.05409887, 0.088759385, 0.0019655675,
|
|
0.0042065294, 0.03881498, 0.019844765, 0.041858196,
|
|
-0.05695512, 0.047233116, 0.038937137, -0.06542224,
|
|
0.014429736, -0.09719407, 0.13908425, -0.05379757,
|
|
0.012321099, 0.082840554, -0.029899208, 0.044217527,
|
|
0.059855383, 0.07711018, -0.045319796, 0.0948846,
|
|
-0.011724666, -0.0033288454, -0.033542685, -0.04764985,
|
|
-0.13873616, 0.040668588, 0.034832682, -0.015319203,
|
|
-0.018715994, 0.046002675, 0.0599172, -0.043107376,
|
|
0.0294216, -0.002314414, -0.022424703, 0.0030315618,
|
|
0.0014641669, 0.0029166266, -0.11878115, 0.013738511,
|
|
0.12375372, -0.0006038222, 0.029104086, 0.087442465,
|
|
0.052958444, 0.07558703, 0.04817258, 0.044462286,
|
|
-0.015213451, -0.08783778, -0.0561384, -0.003008196,
|
|
0.047060397, -0.002058388, 0.03429439, -0.018839769,
|
|
0.024734668, 0.024614193, -0.042046934, 0.09597743,
|
|
-0.0043254104, 0.04320769, 0.0064070094, -0.0019131786,
|
|
-0.02558259, -0.022822596, -0.023273505, -0.02464396,
|
|
-0.10991725, -0.006240552, 0.0074488563, 0.024044557,
|
|
0.04383914, -0.046476185, 0.028658995, 0.060410924,
|
|
0.050786525, 0.009452605, -0.0073054377, -0.024810238,
|
|
0.0052906186, 0.0066939713, -0.0020913032, 0.014515517,
|
|
0.015898481, 0.021362653, -0.030262267, 0.016587038,
|
|
-0.011442813, 0.041154444, -0.007631438, -0.03423484,
|
|
-0.010977775, 0.036152758, 0.0066366293, 0.11915515,
|
|
0.02318443, -0.041350313, 0.021485701, -0.10906167,
|
|
-0.028218046, -0.00954771, 0.020531068, -0.11995105,
|
|
-0.03672871, 0.024019798, 0.014255957, -0.05221243,
|
|
-0.00661567, -0.04630967, 0.033188973, 0.10107534,
|
|
-0.014027541, 0.030796422, -0.10270911, -0.035999842,
|
|
0.15443139, 0.07684145, 0.036571592, -0.035900835,
|
|
-0.0034699554, 0.06209149, 0.015920248, -0.031122351,
|
|
-0.03858649, 0.01849943, 0.13872518, 0.01503974,
|
|
0.069941424, -0.06948533, -0.0088794185, 0.061282158,
|
|
-0.047401894, 0.03100163, -0.041533746, -0.10430945,
|
|
0.044574402, -0.01425562, -0.024290353, 0.034563623,
|
|
0.05866852, 0.023947537, -0.09445152, 0.035450947,
|
|
0.02247216, -0.0042998926, 0.061146557, -0.10250651,
|
|
0.020881841, -0.06747029, 0.10062043, -0.0023941975,
|
|
0.03532124, -0.016341697, 0.09685456, -0.016764693,
|
|
0.051808182, 0.05875331, -0.04536488, 0.001626336,
|
|
-0.028892258, -0.01048663, -0.009793449, -0.017093895,
|
|
0.010987891, 0.02357273, -0.00010856845, 0.0099760275,
|
|
-0.001845119, -0.03551521, 0.0018358806, 0.05763657,
|
|
-0.01769146, 0.040995963, 0.02235177, -0.060430344,
|
|
0.11475477, -0.023854522, 0.10071741, 0.0686208,
|
|
-0.014250481, 0.034261297, 0.047418304, 0.08562733,
|
|
-0.030519066, 0.0060542435, 0.014653856, -0.038836084,
|
|
0.04096551, 0.032249358, -0.08355519, -0.026823482,
|
|
0.056386515, -0.010401743, -0.028396193, 0.08507674,
|
|
0.014410365, 0.020995233, 0.17040324, 0.11511526,
|
|
0.02459721, 0.0066619175, 0.025853224, -0.023133837,
|
|
-0.081302024, 0.017264642, -0.009585969, 0.09491168,
|
|
-0.051313367, 0.054532815, -0.014298593, 0.10657464,
|
|
0.007076659, 0.10964551, 0.0409152, 0.008275321,
|
|
-0.07283536, 0.07937492, 0.04192024, -0.1075027};
|
|
|
|
recurrent_to_output_weights_ = {
|
|
0.025825322, -0.05813119, 0.09495884, -0.045984812,
|
|
-0.01255415, -0.0026479573, -0.08196161, -0.054914974,
|
|
-0.0046604523, -0.029587349, -0.044576716, -0.07480124,
|
|
-0.082868785, 0.023254942, 0.027502948, -0.0039728214,
|
|
-0.08683098, -0.08116779, -0.014675607, -0.037924774,
|
|
-0.023314456, -0.007401714, -0.09255757, 0.029460307,
|
|
-0.08829125, -0.005139627, -0.08989442, -0.0555066,
|
|
0.13596267, -0.025062224, -0.048351806, -0.03850004,
|
|
0.07266485, -0.022414139, 0.05940088, 0.075114764,
|
|
0.09597592, -0.010211725, -0.0049794707, -0.011523867,
|
|
-0.025980417, 0.072999895, 0.11091378, -0.081685916,
|
|
0.014416728, 0.043229222, 0.034178585, -0.07530371,
|
|
0.035837382, -0.085607, -0.007721233, -0.03287832,
|
|
-0.043848954, -0.06404588, -0.06632928, -0.073643476,
|
|
0.008214239, -0.045984086, 0.039764922, 0.03474462,
|
|
0.060612556, -0.080590084, 0.049127717, 0.04151091,
|
|
-0.030063879, 0.008801774, -0.023021035, -0.019558564,
|
|
0.05158114, -0.010947698, -0.011825728, 0.0075720972,
|
|
0.0699727, -0.0039981045, 0.069350146, 0.08799282,
|
|
0.016156472, 0.035502106, 0.11695009, 0.006217345,
|
|
0.13392477, -0.037875112, 0.025745004, 0.08940699,
|
|
-0.00924166, 0.0046702605, -0.036598757, -0.08811812,
|
|
0.10522024, -0.032441203, 0.008176899, -0.04454919,
|
|
0.07058152, 0.0067963637, 0.039206743, 0.03259838,
|
|
0.03725492, -0.09515802, 0.013326398, -0.052055415,
|
|
-0.025676316, 0.03198509, -0.015951829, -0.058556724,
|
|
0.036879618, 0.043357447, 0.028362012, -0.05908629,
|
|
0.0059240665, -0.04995891, -0.019187413, 0.0276265,
|
|
-0.01628143, 0.0025863599, 0.08800015, 0.035250366,
|
|
-0.022165963, -0.07328642, -0.009415526, -0.07455109,
|
|
0.11690406, 0.0363299, 0.07411125, 0.042103454,
|
|
-0.009660886, 0.019076364, 0.018299393, -0.046004917,
|
|
0.08891175, 0.0431396, -0.026327137, -0.051502608,
|
|
0.08979574, -0.051670972, 0.04940282, -0.07491107,
|
|
-0.021240504, 0.022596184, -0.034280192, 0.060163025,
|
|
-0.058211457, -0.051837247, -0.01349775, -0.04639988,
|
|
-0.035936575, -0.011681591, 0.064818054, 0.0073146066,
|
|
-0.021745546, -0.043124277, -0.06471268, -0.07053354,
|
|
-0.029321948, -0.05330136, 0.016933719, -0.053782392,
|
|
0.13747959, -0.1361751, -0.11569455, 0.0033329215,
|
|
0.05693899, -0.053219706, 0.063698, 0.07977434,
|
|
-0.07924483, 0.06936997, 0.0034815092, -0.007305279,
|
|
-0.037325785, -0.07251102, -0.033633437, -0.08677009,
|
|
0.091591336, -0.14165086, 0.021752775, 0.019683983,
|
|
0.0011612234, -0.058154266, 0.049996935, 0.0288841,
|
|
-0.0024567875, -0.14345716, 0.010955264, -0.10234828,
|
|
0.1183656, -0.0010731248, -0.023590032, -0.072285876,
|
|
-0.0724771, -0.026382286, -0.0014920527, 0.042667855,
|
|
0.0018776858, 0.02986552, 0.009814309, 0.0733756,
|
|
0.12289186, 0.018043943, -0.0458958, 0.049412545,
|
|
0.033632483, 0.05495232, 0.036686596, -0.013781798,
|
|
-0.010036754, 0.02576849, -0.08307328, 0.010112348,
|
|
0.042521734, -0.05869831, -0.071689695, 0.03876447,
|
|
-0.13275425, -0.0352966, -0.023077697, 0.10285965,
|
|
0.084736146, 0.15568255, -0.00040734606, 0.027835453,
|
|
-0.10292561, -0.032401145, 0.10053256, -0.026142767,
|
|
-0.08271222, -0.0030240538, -0.016368777, 0.1070414,
|
|
0.042672627, 0.013456989, -0.0437609, -0.022309763,
|
|
0.11576483, 0.04108048, 0.061026827, -0.0190714,
|
|
-0.0869359, 0.037901703, 0.0610107, 0.07202949,
|
|
0.01675338, 0.086139716, -0.08795751, -0.014898893,
|
|
-0.023771819, -0.01965048, 0.007955471, -0.043740474,
|
|
0.03346837, -0.10549954, 0.090567775, 0.042013682,
|
|
-0.03176985, 0.12569028, -0.02421228, -0.029526481,
|
|
0.023851605, 0.031539805, 0.05292009, -0.02344001,
|
|
-0.07811758, -0.08834428, 0.10094801, 0.16594367,
|
|
-0.06861939, -0.021256343, -0.041093912, -0.06669611,
|
|
0.035498552, 0.021757556, -0.09302526, -0.015403468,
|
|
-0.06614931, -0.051798206, -0.013874718, 0.03630673,
|
|
0.010412845, -0.08077351, 0.046185967, 0.0035662893,
|
|
0.03541868, -0.094149634, -0.034814864, 0.003128424,
|
|
-0.020674974, -0.03944324, -0.008110165, -0.11113267,
|
|
0.08484226, 0.043586485, 0.040582247, 0.0968012,
|
|
-0.065249965, -0.028036479, 0.0050708856, 0.0017462453,
|
|
0.0326779, 0.041296225, 0.09164146, -0.047743853,
|
|
-0.015952192, -0.034451712, 0.084197424, -0.05347844,
|
|
-0.11768019, 0.085926116, -0.08251791, -0.045081906,
|
|
0.0948852, 0.068401024, 0.024856757, 0.06978981,
|
|
-0.057309967, -0.012775832, -0.0032452994, 0.01977615,
|
|
-0.041040014, -0.024264973, 0.063464895, 0.05431621,
|
|
};
|
|
|
|
cell_to_input_weights_ = {
|
|
0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458,
|
|
-0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174,
|
|
-0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047,
|
|
0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175};
|
|
|
|
cell_to_forget_weights_ = {
|
|
-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276,
|
|
-0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
|
|
-0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774,
|
|
0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355};
|
|
|
|
cell_to_output_weights_ = {
|
|
0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764,
|
|
-0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544,
|
|
-0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817,
|
|
0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733};
|
|
|
|
projection_weights_ = {
|
|
-0.009802181, 0.09401916, 0.0717386, -0.13895074,
|
|
0.09641832, 0.060420845, 0.08539281, 0.054285463,
|
|
0.061395317, 0.034448683, -0.042991187, 0.019801661,
|
|
-0.16840284, -0.015726732, -0.23041931, -0.024478018,
|
|
-0.10959692, -0.013875541, 0.18600968, -0.061274476,
|
|
0.0138165, -0.08160894, -0.07661644, 0.032372914,
|
|
0.16169067, 0.22465782, -0.03993472, -0.004017731,
|
|
0.08633481, -0.28869787, 0.08682067, 0.17240396,
|
|
0.014975425, 0.056431185, 0.031037588, 0.16702051,
|
|
0.0077946745, 0.15140012, 0.29405436, 0.120285,
|
|
-0.188994, -0.027265169, 0.043389652, -0.022061434,
|
|
0.014777949, -0.20203483, 0.094781205, 0.19100232,
|
|
0.13987629, -0.036132768, -0.06426278, -0.05108664,
|
|
0.13221376, 0.009441198, -0.16715929, 0.15859416,
|
|
-0.040437475, 0.050779544, -0.022187516, 0.012166504,
|
|
0.027685808, -0.07675938, -0.0055694645, -0.09444123,
|
|
0.0046453946, 0.050794356, 0.10770313, -0.20790008,
|
|
-0.07149004, -0.11425117, 0.008225835, -0.035802525,
|
|
0.14374903, 0.15262283, 0.048710253, 0.1847461,
|
|
-0.007487823, 0.11000021, -0.09542012, 0.22619456,
|
|
-0.029149994, 0.08527916, 0.009043713, 0.0042746216,
|
|
0.016261552, 0.022461696, 0.12689082, -0.043589946,
|
|
-0.12035478, -0.08361797, -0.050666027, -0.1248618,
|
|
-0.1275799, -0.071875185, 0.07377272, 0.09944291,
|
|
-0.18897448, -0.1593054, -0.06526116, -0.040107165,
|
|
-0.004618631, -0.067624845, -0.007576253, 0.10727444,
|
|
0.041546922, -0.20424393, 0.06907816, 0.050412357,
|
|
0.00724631, 0.039827548, 0.12449835, 0.10747581,
|
|
0.13708383, 0.09134148, -0.12617786, -0.06428341,
|
|
0.09956831, 0.1208086, -0.14676677, -0.0727722,
|
|
0.1126304, 0.010139365, 0.015571211, -0.038128063,
|
|
0.022913318, -0.042050496, 0.16842307, -0.060597885,
|
|
0.10531834, -0.06411776, -0.07451711, -0.03410368,
|
|
-0.13393489, 0.06534304, 0.003620307, 0.04490757,
|
|
0.05970546, 0.05197996, 0.02839995, 0.10434969,
|
|
-0.013699693, -0.028353551, -0.07260381, 0.047201227,
|
|
-0.024575593, -0.036445823, 0.07155557, 0.009672501,
|
|
-0.02328883, 0.009533515, -0.03606021, -0.07421458,
|
|
-0.028082801, -0.2678904, -0.13221288, 0.18419984,
|
|
-0.13012612, -0.014588381, -0.035059117, -0.04824723,
|
|
0.07830115, -0.056184657, 0.03277091, 0.025466874,
|
|
0.14494097, -0.12522776, -0.098633975, -0.10766018,
|
|
-0.08317623, 0.08594209, 0.07749552, 0.039474737,
|
|
0.1776665, -0.07409566, -0.0477268, 0.29323658,
|
|
0.10801441, 0.1154011, 0.013952499, 0.10739139,
|
|
0.10708251, -0.051456142, 0.0074137426, -0.10430189,
|
|
0.10034707, 0.045594677, 0.0635285, -0.0715442,
|
|
-0.089667566, -0.10811871, 0.00026344223, 0.08298446,
|
|
-0.009525053, 0.006585689, -0.24567553, -0.09450807,
|
|
0.09648481, 0.026996298, -0.06419476, -0.04752702,
|
|
-0.11063944, -0.23441927, -0.17608605, -0.052156363,
|
|
0.067035615, 0.19271925, -0.0032889997, -0.043264326,
|
|
0.09663576, -0.057112187, -0.10100678, 0.0628376,
|
|
0.04447668, 0.017961001, -0.10094388, -0.10190601,
|
|
0.18335468, 0.10494553, -0.052095775, -0.0026118709,
|
|
0.10539724, -0.04383912, -0.042349473, 0.08438151,
|
|
-0.1947263, 0.02251204, 0.11216432, -0.10307853,
|
|
0.17351969, -0.039091777, 0.08066188, -0.00561982,
|
|
0.12633002, 0.11335965, -0.0088127935, -0.019777594,
|
|
0.06864014, -0.059751723, 0.016233567, -0.06894641,
|
|
-0.28651384, -0.004228674, 0.019708522, -0.16305895,
|
|
-0.07468996, -0.0855457, 0.099339016, -0.07580735,
|
|
-0.13775392, 0.08434318, 0.08330512, -0.12131499,
|
|
0.031935584, 0.09180414, -0.08876437, -0.08049874,
|
|
0.008753825, 0.03498998, 0.030215185, 0.03907079,
|
|
0.089751154, 0.029194152, -0.03337423, -0.019092513,
|
|
0.04331237, 0.04299654, -0.036394123, -0.12915532,
|
|
0.09793732, 0.07512415, -0.11319543, -0.032502122,
|
|
0.15661901, 0.07671967, -0.005491124, -0.19379048,
|
|
-0.218606, 0.21448623, 0.017840758, 0.1416943,
|
|
-0.07051762, 0.19488361, 0.02664691, -0.18104725,
|
|
-0.09334311, 0.15026465, -0.15493552, -0.057762887,
|
|
-0.11604192, -0.262013, -0.01391798, 0.012185008,
|
|
0.11156489, -0.07483202, 0.06693364, -0.26151478,
|
|
0.046425626, 0.036540434, -0.16435726, 0.17338543,
|
|
-0.21401681, -0.11385144, -0.08283257, -0.069031075,
|
|
0.030635102, 0.010969227, 0.11109743, 0.010919218,
|
|
0.027526086, 0.13519906, 0.01891392, -0.046839405,
|
|
-0.040167913, 0.017953383, -0.09700955, 0.0061885654,
|
|
-0.07000971, 0.026893595, -0.038844477, 0.14543656};
|
|
|
|
lstm_input_ = {
|
|
{// Batch0: 4 (input_sequence_size) * 5 (n_input)
|
|
0.787926, 0.151646, 0.071352, 0.118426, 0.458058, // step 0
|
|
0.596268, 0.998386, 0.568695, 0.864524, 0.571277, // step 1
|
|
0.073204, 0.296072, 0.743333, 0.069199, 0.045348, // step 2
|
|
0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, // step 3
|
|
|
|
{// Batch1: 4 (input_sequence_size) * 5 (n_input)
|
|
0.295743, 0.544053, 0.690064, 0.858138, 0.497181, // step 0
|
|
0.642421, 0.524260, 0.134799, 0.003639, 0.162482, // step 1
|
|
0.640394, 0.930399, 0.050782, 0.432485, 0.988078, // step 2
|
|
0.082922, 0.563329, 0.865614, 0.333232, 0.259916} // step 3
|
|
};
|
|
|
|
lstm_golden_output_ = {
|
|
{// Batch0: 4 (input_sequence_size) * 16 (n_output)
|
|
-0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576,
|
|
-0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004,
|
|
-0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147,
|
|
0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363,
|
|
-0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322,
|
|
-0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308,
|
|
0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794,
|
|
0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474,
|
|
0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827,
|
|
0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512,
|
|
-0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407,
|
|
-0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193,
|
|
0.0286833, 0.00824207, 0.0264887, 0.0305169},
|
|
{// Batch1: 4 (input_sequence_size) * 16 (n_output)
|
|
-0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926,
|
|
-0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232,
|
|
0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954,
|
|
0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507,
|
|
-0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039,
|
|
-0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233,
|
|
0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378,
|
|
0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034,
|
|
0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789,
|
|
0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855,
|
|
-0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679,
|
|
-0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181,
|
|
0.0412031, 0.0118723, 0.0239643, 0.0394009}};
|
|
}
|
|
};
|
|
|
|
// Float black-box test for this fixture's configuration:
//   - no CIFG (a separate input gate is present),
//   - peephole connections,
//   - projection weights but no projection bias.
// time_major is true, so the input tensor shape is
// {sequence_length, n_batch, n_input}. VerifyGoldens is called without an
// explicit tolerance argument, i.e. it uses whatever default it declares.
//
// NOTE(review): despite "Clipping" in the fixture name, cell_clip and
// proj_clip are both 0.0 here. Presumably 0.0 disables clipping in the
// kernel -- confirm against the op definition.
TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
       LstmBlackBoxTest) {
  constexpr int n_batch = 2;
  constexpr int n_input = 5;
  constexpr int n_cell = 20;
  constexpr int n_output = 16;
  constexpr int sequence_length = 4;

  // Shapes for every model input. projection_bias gets shape {0} because
  // use_projection_bias is false below (presumably an empty placeholder).
  const std::vector<std::vector<int>> tensor_shapes = {
      {sequence_length, n_batch, n_input},  // input tensor

      {n_cell, n_input},  // input_to_input_weight tensor
      {n_cell, n_input},  // input_to_forget_weight tensor
      {n_cell, n_input},  // input_to_cell_weight tensor
      {n_cell, n_input},  // input_to_output_weight tensor

      {n_cell, n_output},  // recurrent_to_input_weight tensor
      {n_cell, n_output},  // recurrent_to_forget_weight tensor
      {n_cell, n_output},  // recurrent_to_cell_weight tensor
      {n_cell, n_output},  // recurrent_to_output_weight tensor

      {n_cell},  // cell_to_input_weight tensor
      {n_cell},  // cell_to_forget_weight tensor
      {n_cell},  // cell_to_output_weight tensor

      {n_cell},  // input_gate_bias tensor
      {n_cell},  // forget_gate_bias tensor
      {n_cell},  // cell_gate_bias tensor
      {n_cell},  // output_gate_bias tensor

      {n_output, n_cell},  // projection_weight tensor
      {0},                 // projection_bias tensor

      {n_batch, n_output},  // output_state tensor
      {n_batch, n_cell},    // cell_state tensor
  };

  UnidirectionalLSTMOpModel lstm(
      n_batch, n_input, n_cell, n_output, sequence_length,
      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true,
      /*use_projection_weights=*/true, /*use_projection_bias=*/false,
      /*cell_clip=*/0.0, /*proj_clip=*/0.0, tensor_shapes);

  // Input-to-gate weights.
  lstm.SetInputToInputWeights(input_to_input_weights_);
  lstm.SetInputToCellWeights(input_to_cell_weights_);
  lstm.SetInputToForgetWeights(input_to_forget_weights_);
  lstm.SetInputToOutputWeights(input_to_output_weights_);

  // Gate biases.
  lstm.SetInputGateBias(input_gate_bias_);
  lstm.SetCellBias(cell_gate_bias_);
  lstm.SetForgetGateBias(forget_gate_bias_);
  lstm.SetOutputGateBias(output_gate_bias_);

  // Recurrent (output-state-to-gate) weights.
  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);

  // Peephole weights.
  lstm.SetCellToInputWeights(cell_to_input_weights_);
  lstm.SetCellToForgetWeights(cell_to_forget_weights_);
  lstm.SetCellToOutputWeights(cell_to_output_weights_);

  // Projection (no bias in this configuration).
  lstm.SetProjectionWeights(projection_weights_);

  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
|
|
|
|
// Hybrid black-box test for this fixture's configuration with weights of type
// TensorType_UINT8:
//   - no CIFG,
//   - peephole connections,
//   - projection weights without bias.
// The float goldens are compared with an explicit, looser tolerance of
// 0.00467 (presumably to absorb weight-quantization error).
//
// The test parameter is forwarded as the model's final constructor argument.
// Presumably this is asymmetric_quantize_inputs, as in
// UnidirectionalLSTMOpModel -- confirm. Only the `false` case is exercised.
// The `true` case is reported as SKIPPED instead of silently returning,
// which gtest would count as a pass even though nothing was verified.
TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
       HybridLstmBlackBoxTestUint8) {
  if (GetParam()) {
    GTEST_SKIP() << "This test only runs with the test parameter set to false.";
  }

  const int n_batch = 2;
  const int n_input = 5;
  const int n_cell = 20;
  const int n_output = 16;
  const int sequence_length = 4;

  // Shapes for every model input. projection_bias gets shape {0} because
  // use_projection_bias is false below.
  const std::vector<std::vector<int>> tensor_shapes = {
      {sequence_length, n_batch, n_input},  // input tensor

      {n_cell, n_input},  // input_to_input_weight tensor
      {n_cell, n_input},  // input_to_forget_weight tensor
      {n_cell, n_input},  // input_to_cell_weight tensor
      {n_cell, n_input},  // input_to_output_weight tensor

      {n_cell, n_output},  // recurrent_to_input_weight tensor
      {n_cell, n_output},  // recurrent_to_forget_weight tensor
      {n_cell, n_output},  // recurrent_to_cell_weight tensor
      {n_cell, n_output},  // recurrent_to_output_weight tensor

      {n_cell},  // cell_to_input_weight tensor
      {n_cell},  // cell_to_forget_weight tensor
      {n_cell},  // cell_to_output_weight tensor

      {n_cell},  // input_gate_bias tensor
      {n_cell},  // forget_gate_bias tensor
      {n_cell},  // cell_gate_bias tensor
      {n_cell},  // output_gate_bias tensor

      {n_output, n_cell},  // projection_weight tensor
      {0},                 // projection_bias tensor

      {n_batch, n_output},  // output_state tensor
      {n_batch, n_cell},    // cell_state tensor
  };

  HybridUnidirectionalLSTMOpModel lstm(
      n_batch, n_input, n_cell, n_output, sequence_length,
      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true,
      /*use_projection_weights=*/true, /*use_projection_bias=*/false,
      /*cell_clip=*/0.0, /*proj_clip=*/0.0, tensor_shapes, TensorType_UINT8,
      GetParam());

  // Input-to-gate weights.
  lstm.SetInputToInputWeights(input_to_input_weights_);
  lstm.SetInputToCellWeights(input_to_cell_weights_);
  lstm.SetInputToForgetWeights(input_to_forget_weights_);
  lstm.SetInputToOutputWeights(input_to_output_weights_);

  // Gate biases.
  lstm.SetInputGateBias(input_gate_bias_);
  lstm.SetCellBias(cell_gate_bias_);
  lstm.SetForgetGateBias(forget_gate_bias_);
  lstm.SetOutputGateBias(output_gate_bias_);

  // Recurrent weights.
  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);

  // Peephole weights.
  lstm.SetCellToInputWeights(cell_to_input_weights_);
  lstm.SetCellToForgetWeights(cell_to_forget_weights_);
  lstm.SetCellToOutputWeights(cell_to_output_weights_);

  // Projection (no bias in this configuration).
  lstm.SetProjectionWeights(projection_weights_);

  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
}
|
|
|
|
// Hybrid black-box test for this fixture's configuration with weights of type
// TensorType_INT8:
//   - no CIFG,
//   - peephole connections,
//   - projection weights without bias.
// The float goldens are compared with an explicit, looser tolerance of
// 0.00467 (presumably to absorb weight-quantization error).
//
// The test parameter is forwarded as the model's final constructor argument.
// Presumably this is asymmetric_quantize_inputs, as in
// UnidirectionalLSTMOpModel -- confirm. Only the `false` case is exercised.
// The `true` case is reported as SKIPPED instead of silently returning,
// which gtest would count as a pass even though nothing was verified.
TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
       HybridLstmBlackBoxTestInt8) {
  if (GetParam()) {
    GTEST_SKIP() << "This test only runs with the test parameter set to false.";
  }

  const int n_batch = 2;
  const int n_input = 5;
  const int n_cell = 20;
  const int n_output = 16;
  const int sequence_length = 4;

  // Shapes for every model input. projection_bias gets shape {0} because
  // use_projection_bias is false below.
  const std::vector<std::vector<int>> tensor_shapes = {
      {sequence_length, n_batch, n_input},  // input tensor

      {n_cell, n_input},  // input_to_input_weight tensor
      {n_cell, n_input},  // input_to_forget_weight tensor
      {n_cell, n_input},  // input_to_cell_weight tensor
      {n_cell, n_input},  // input_to_output_weight tensor

      {n_cell, n_output},  // recurrent_to_input_weight tensor
      {n_cell, n_output},  // recurrent_to_forget_weight tensor
      {n_cell, n_output},  // recurrent_to_cell_weight tensor
      {n_cell, n_output},  // recurrent_to_output_weight tensor

      {n_cell},  // cell_to_input_weight tensor
      {n_cell},  // cell_to_forget_weight tensor
      {n_cell},  // cell_to_output_weight tensor

      {n_cell},  // input_gate_bias tensor
      {n_cell},  // forget_gate_bias tensor
      {n_cell},  // cell_gate_bias tensor
      {n_cell},  // output_gate_bias tensor

      {n_output, n_cell},  // projection_weight tensor
      {0},                 // projection_bias tensor

      {n_batch, n_output},  // output_state tensor
      {n_batch, n_cell},    // cell_state tensor
  };

  HybridUnidirectionalLSTMOpModel lstm(
      n_batch, n_input, n_cell, n_output, sequence_length,
      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true,
      /*use_projection_weights=*/true, /*use_projection_bias=*/false,
      /*cell_clip=*/0.0, /*proj_clip=*/0.0, tensor_shapes, TensorType_INT8,
      GetParam());

  // Input-to-gate weights.
  lstm.SetInputToInputWeights(input_to_input_weights_);
  lstm.SetInputToCellWeights(input_to_cell_weights_);
  lstm.SetInputToForgetWeights(input_to_forget_weights_);
  lstm.SetInputToOutputWeights(input_to_output_weights_);

  // Gate biases.
  lstm.SetInputGateBias(input_gate_bias_);
  lstm.SetCellBias(cell_gate_bias_);
  lstm.SetForgetGateBias(forget_gate_bias_);
  lstm.SetOutputGateBias(output_gate_bias_);

  // Recurrent weights.
  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);

  // Peephole weights.
  lstm.SetCellToInputWeights(cell_to_input_weights_);
  lstm.SetCellToForgetWeights(cell_to_forget_weights_);
  lstm.SetCellToOutputWeights(cell_to_output_weights_);

  // Projection (no bias in this configuration).
  lstm.SetProjectionWeights(projection_weights_);

  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
}
|
|
|
|
class NoCifgPeepholeProjectionAndBiasClippingUnidirectionalLstmTest
|
|
: public BaseUnidirectionalLstmTest {
|
|
void SetUp() override {
|
|
input_to_input_weights_ = {
|
|
0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
|
|
0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048,
|
|
-0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385,
|
|
-0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282,
|
|
-0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627,
|
|
-0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226,
|
|
-0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059,
|
|
0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698,
|
|
0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206,
|
|
0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585,
|
|
-0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063,
|
|
0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603,
|
|
-0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682,
|
|
-0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988,
|
|
-0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764,
|
|
0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476,
|
|
-0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012,
|
|
-0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604,
|
|
-0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654,
|
|
-0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677};
|
|
|
|
input_to_forget_weights_ = {
|
|
-0.0018401089, -0.004852237, 0.03698424, 0.014181704,
|
|
0.028273236, -0.016726194, -0.05249759, -0.10204261,
|
|
0.00861066, -0.040979505, -0.009899187, 0.01923892,
|
|
-0.028177269, -0.08535103, -0.14585495, 0.10662567,
|
|
-0.01909731, -0.017883534, -0.0047269356, -0.045103323,
|
|
0.0030784295, 0.076784775, 0.07463696, 0.094531395,
|
|
0.0814421, -0.12257899, -0.033945758, -0.031303465,
|
|
0.045630626, 0.06843887, -0.13492945, -0.012480007,
|
|
-0.0811829, -0.07224499, -0.09628791, 0.045100946,
|
|
0.0012300825, 0.013964662, 0.099372394, 0.02543059,
|
|
0.06958324, 0.034257296, 0.0482646, 0.06267997,
|
|
0.052625068, 0.12784666, 0.07077897, 0.025725935,
|
|
0.04165009, 0.07241905, 0.018668644, -0.037377294,
|
|
-0.06277783, -0.08833636, -0.040120605, -0.011405586,
|
|
-0.007808335, -0.010301386, -0.005102167, 0.027717464,
|
|
0.05483423, 0.11449111, 0.11289652, 0.10939839,
|
|
0.13396506, -0.08402166, -0.01901462, -0.044678304,
|
|
-0.07720565, 0.014350063, -0.11757958, -0.0652038,
|
|
-0.08185733, -0.076754324, -0.092614375, 0.10405491,
|
|
0.052960336, 0.035755895, 0.035839386, -0.012540553,
|
|
0.036881298, 0.02913376, 0.03420159, 0.05448447,
|
|
-0.054523353, 0.02582715, 0.02327355, -0.011857179,
|
|
-0.0011980024, -0.034641717, -0.026125094, -0.17582615,
|
|
-0.15923657, -0.27486774, -0.0006143371, 0.0001771948,
|
|
-8.470171e-05, 0.02651807, 0.045790765, 0.06956496};
|
|
|
|
input_to_cell_weights_ = {
|
|
-0.04580283, -0.09549462, -0.032418985, -0.06454633,
|
|
-0.043528453, 0.043018587, -0.049152344, -0.12418144,
|
|
-0.078985475, -0.07596889, 0.019484362, -0.11434962,
|
|
-0.0074034138, -0.06314844, -0.092981495, 0.0062155537,
|
|
-0.025034338, -0.0028890965, 0.048929527, 0.06235075,
|
|
0.10665918, -0.032036792, -0.08505916, -0.10843358,
|
|
-0.13002433, -0.036816437, -0.02130134, -0.016518239,
|
|
0.0047691227, -0.0025825808, 0.066017866, 0.029991534,
|
|
-0.10652836, -0.1037554, -0.13056071, -0.03266643,
|
|
-0.033702414, -0.006473424, -0.04611692, 0.014419339,
|
|
-0.025174323, 0.0396852, 0.081777506, 0.06157468,
|
|
0.10210095, -0.009658194, 0.046511717, 0.03603906,
|
|
0.0069369148, 0.015960095, -0.06507666, 0.09551598,
|
|
0.053568836, 0.06408714, 0.12835667, -0.008714329,
|
|
-0.20211966, -0.12093674, 0.029450472, 0.2849013,
|
|
-0.029227901, 0.1164364, -0.08560263, 0.09941786,
|
|
-0.036999565, -0.028842626, -0.0033637602, -0.017012902,
|
|
-0.09720865, -0.11193351, -0.029155117, -0.017936034,
|
|
-0.009768936, -0.04223324, -0.036159635, 0.06505112,
|
|
-0.021742892, -0.023377212, -0.07221364, -0.06430552,
|
|
0.05453865, 0.091149814, 0.06387331, 0.007518393,
|
|
0.055960953, 0.069779344, 0.046411168, 0.10509911,
|
|
0.07463894, 0.0075130584, 0.012850982, 0.04555431,
|
|
0.056955688, 0.06555285, 0.050801456, -0.009862683,
|
|
0.00826772, -0.026555609, -0.0073611983, -0.0014897042};
|
|
|
|
input_to_output_weights_ = {
|
|
-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918,
|
|
-0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534,
|
|
0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722,
|
|
-0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761,
|
|
-0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394,
|
|
0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154,
|
|
-0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135,
|
|
-0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564,
|
|
-0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047,
|
|
-0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304,
|
|
0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946,
|
|
0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646,
|
|
0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813,
|
|
-0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403,
|
|
0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415,
|
|
0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495,
|
|
-0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158,
|
|
0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295,
|
|
-0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739,
|
|
-0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956};
|
|
|
|
input_gate_bias_ = {0.02234832, 0.14757581, 0.18176508, 0.10380666,
|
|
0.053110216, -0.06928846, -0.13942584, -0.11816189,
|
|
0.19483899, 0.03652339, -0.10250295, 0.036714908,
|
|
-0.18426876, 0.036065217, 0.21810818, 0.02383196,
|
|
-0.043370757, 0.08690144, -0.04444982, 0.00030581196};
|
|
|
|
forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696,
|
|
0.11098921, 0.15378423, 0.09263801, 0.09790885,
|
|
0.09508917, 0.061199076, 0.07665568, -0.015443159,
|
|
-0.03499149, 0.046190713, 0.08895977, 0.10899629,
|
|
0.40694186, 0.06030037, 0.012413437, -0.06108739};
|
|
|
|
cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132, 0.033463873,
|
|
-0.1483596, -0.10639995, -0.091433935, 0.058573797,
|
|
-0.06809782, -0.07889636, -0.043246906, -0.09829136,
|
|
-0.4279842, 0.034901652, 0.18797937, 0.0075234566,
|
|
0.016178843, 0.1749513, 0.13975595, 0.92058027};
|
|
|
|
output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469, 0.12648113,
|
|
0.027195795, 0.35373217, -0.018957434, 0.008907322,
|
|
-0.0762701, 0.12018895, 0.04216877, 0.0022856654,
|
|
0.040952638, 0.3147856, 0.08225149, -0.057416286,
|
|
-0.14995944, -0.008040261, 0.13208859, 0.029760877};
|
|
|
|
recurrent_to_input_weights_ = {
|
|
-0.001374326, -0.078856036, 0.10672688, 0.029162422,
|
|
-0.11585556, 0.02557986, -0.13446963, -0.035785314,
|
|
-0.01244275, 0.025961924, -0.02337298, -0.044228926,
|
|
-0.055839065, -0.046598054, -0.010546039, -0.06900766,
|
|
0.027239809, 0.022582639, -0.013296484, -0.05459212,
|
|
0.08981, -0.045407712, 0.08682226, -0.06867011,
|
|
-0.14390695, -0.02916037, 0.000996957, 0.091420636,
|
|
0.14283475, -0.07390571, -0.06402044, 0.062524505,
|
|
-0.093129106, 0.04860203, -0.08364217, -0.08119002,
|
|
0.009352075, 0.22920375, 0.0016303885, 0.11583097,
|
|
-0.13732095, 0.012405723, -0.07551853, 0.06343048,
|
|
0.12162708, -0.031923793, -0.014335606, 0.01790974,
|
|
-0.10650317, -0.0724401, 0.08554849, -0.05727212,
|
|
0.06556731, -0.042729504, -0.043227166, 0.011683251,
|
|
-0.013082158, -0.029302018, -0.010899579, -0.062036745,
|
|
-0.022509435, -0.00964907, -0.01567329, 0.04260106,
|
|
-0.07787477, -0.11576462, 0.017356863, 0.048673786,
|
|
-0.017577527, -0.05527947, -0.082487635, -0.040137455,
|
|
-0.10820036, -0.04666372, 0.022746278, -0.07851417,
|
|
0.01068115, 0.032956902, 0.022433773, 0.0026891115,
|
|
0.08944216, -0.0685835, 0.010513544, 0.07228705,
|
|
0.02032331, -0.059686817, -0.0005566496, -0.086984694,
|
|
0.040414046, -0.1380399, 0.094208956, -0.05722982,
|
|
0.012092817, -0.04989123, -0.086576, -0.003399834,
|
|
-0.04696032, -0.045747425, 0.10091314, 0.048676282,
|
|
-0.029037097, 0.031399418, -0.0040285117, 0.047237843,
|
|
0.09504992, 0.041799378, -0.049185462, -0.031518843,
|
|
-0.10516937, 0.026374253, 0.10058866, -0.0033195973,
|
|
-0.041975245, 0.0073591834, 0.0033782164, -0.004325073,
|
|
-0.10167381, 0.042500053, -0.01447153, 0.06464186,
|
|
-0.017142897, 0.03312627, 0.009205989, 0.024138335,
|
|
-0.011337001, 0.035530265, -0.010912711, 0.0706555,
|
|
-0.005894094, 0.051841937, -0.1401738, -0.02351249,
|
|
0.0365468, 0.07590991, 0.08838724, 0.021681072,
|
|
-0.10086113, 0.019608743, -0.06195883, 0.077335775,
|
|
0.023646897, -0.095322326, 0.02233014, 0.09756986,
|
|
-0.048691444, -0.009579111, 0.07595467, 0.11480546,
|
|
-0.09801813, 0.019894179, 0.08502348, 0.004032281,
|
|
0.037211012, 0.068537936, -0.048005626, -0.091520436,
|
|
-0.028379958, -0.01556313, 0.06554592, -0.045599163,
|
|
-0.01672207, -0.020169014, -0.011877351, -0.20212261,
|
|
0.010889619, 0.0047078193, 0.038385306, 0.08540671,
|
|
-0.017140968, -0.0035865551, 0.016678626, 0.005633034,
|
|
0.015963363, 0.00871737, 0.060130805, 0.028611384,
|
|
0.10109069, -0.015060172, -0.07894427, 0.06401885,
|
|
0.011584063, -0.024466386, 0.0047652307, -0.09041358,
|
|
0.030737216, -0.0046374933, 0.14215417, -0.11823516,
|
|
0.019899689, 0.006106124, -0.027092824, 0.0786356,
|
|
0.05052217, -0.058925, -0.011402121, -0.024987547,
|
|
-0.0013661642, -0.06832946, -0.015667673, -0.1083353,
|
|
-0.00096863037, -0.06988685, -0.053350925, -0.027275559,
|
|
-0.033664223, -0.07978348, -0.025200296, -0.017207067,
|
|
-0.058403496, -0.055697463, 0.005798788, 0.12965427,
|
|
-0.062582195, 0.0013350133, -0.10482091, 0.0379771,
|
|
0.072521195, -0.0029455067, -0.13797039, -0.03628521,
|
|
0.013806405, -0.017858358, -0.01008298, -0.07700066,
|
|
-0.017081132, 0.019358726, 0.0027079724, 0.004635139,
|
|
0.062634714, -0.02338735, -0.039547626, -0.02050681,
|
|
0.03385117, -0.083611414, 0.002862572, -0.09421313,
|
|
0.058618143, -0.08598433, 0.00972939, 0.023867095,
|
|
-0.053934585, -0.023203006, 0.07452513, -0.048767887,
|
|
-0.07314807, -0.056307215, -0.10433547, -0.06440842,
|
|
0.04328182, 0.04389765, -0.020006588, -0.09076438,
|
|
-0.11652589, -0.021705797, 0.03345259, -0.010329105,
|
|
-0.025767034, 0.013057034, -0.07316461, -0.10145612,
|
|
0.06358255, 0.18531723, 0.07759293, 0.12006465,
|
|
0.1305557, 0.058638252, -0.03393652, 0.09622831,
|
|
-0.16253184, -2.4580743e-06, 0.079869635, -0.070196845,
|
|
-0.005644518, 0.06857898, -0.12598175, -0.035084512,
|
|
0.03156317, -0.12794146, -0.031963028, 0.04692781,
|
|
0.030070418, 0.0071660685, -0.095516115, -0.004643372,
|
|
0.040170413, -0.062104587, -0.0037324072, 0.0554317,
|
|
0.08184801, -0.019164372, 0.06791302, 0.034257166,
|
|
-0.10307039, 0.021943003, 0.046745934, 0.0790918,
|
|
-0.0265588, -0.007824208, 0.042546265, -0.00977924,
|
|
-0.0002440307, -0.017384544, -0.017990116, 0.12252321,
|
|
-0.014512694, -0.08251313, 0.08861942, 0.13589665,
|
|
0.026351685, 0.012641483, 0.07466548, 0.044301085,
|
|
-0.045414884, -0.051112458, 0.03444247, -0.08502782,
|
|
-0.04106223, -0.028126027, 0.028473156, 0.10467447};
|
|
|
|
recurrent_to_cell_weights_ = {
|
|
-0.037322544, 0.018592842, 0.0056175636, -0.06253426,
|
|
0.055647098, -0.05713207, -0.05626563, 0.005559383,
|
|
0.03375411, -0.025757805, -0.088049285, 0.06017052,
|
|
-0.06570978, 0.007384076, 0.035123326, -0.07920549,
|
|
0.053676967, 0.044480428, -0.07663568, 0.0071805613,
|
|
0.08089997, 0.05143358, 0.038261272, 0.03339287,
|
|
-0.027673481, 0.044746667, 0.028349208, 0.020090483,
|
|
-0.019443132, -0.030755889, -0.0040000007, 0.04465846,
|
|
-0.021585021, 0.0031670958, 0.0053199246, -0.056117613,
|
|
-0.10893326, 0.076739706, -0.08509834, -0.027997585,
|
|
0.037871376, 0.01449768, -0.09002357, -0.06111149,
|
|
-0.046195522, 0.0422062, -0.005683705, -0.1253618,
|
|
-0.012925729, -0.04890792, 0.06985068, 0.037654128,
|
|
0.03398274, -0.004781977, 0.007032333, -0.031787455,
|
|
0.010868644, -0.031489216, 0.09525667, 0.013939797,
|
|
0.0058680447, 0.0167067, 0.02668468, -0.04797466,
|
|
-0.048885044, -0.12722108, 0.035304096, 0.06554885,
|
|
0.00972396, -0.039238118, -0.05159735, -0.11329045,
|
|
0.1613692, -0.03750952, 0.06529313, -0.071974665,
|
|
-0.11769596, 0.015524369, -0.0013754242, -0.12446318,
|
|
0.02786344, -0.014179351, 0.005264273, 0.14376344,
|
|
0.015983658, 0.03406988, -0.06939408, 0.040699873,
|
|
0.02111075, 0.09669095, 0.041345075, -0.08316494,
|
|
-0.07684199, -0.045768797, 0.032298047, -0.041805092,
|
|
0.0119405, 0.0061010392, 0.12652606, 0.0064572375,
|
|
-0.024950314, 0.11574242, 0.04508852, -0.04335324,
|
|
0.06760663, -0.027437469, 0.07216407, 0.06977076,
|
|
-0.05438599, 0.034033038, -0.028602652, 0.05346137,
|
|
0.043184172, -0.037189785, 0.10420091, 0.00882477,
|
|
-0.054019816, -0.074273005, -0.030617684, -0.0028467078,
|
|
0.024302477, -0.0038869337, 0.005332455, 0.0013399826,
|
|
0.04361412, -0.007001822, 0.09631092, -0.06702025,
|
|
-0.042049985, -0.035070654, -0.04103342, -0.10273396,
|
|
0.0544271, 0.037184782, -0.13150354, -0.0058036847,
|
|
-0.008264958, 0.042035464, 0.05891794, 0.029673764,
|
|
0.0063542654, 0.044788733, 0.054816857, 0.062257513,
|
|
-0.00093483756, 0.048938446, -0.004952862, -0.007730018,
|
|
-0.04043371, -0.017094059, 0.07229206, -0.023670016,
|
|
-0.052195564, -0.025616996, -0.01520939, 0.045104615,
|
|
-0.007376126, 0.003533447, 0.006570588, 0.056037236,
|
|
0.12436656, 0.051817212, 0.028532185, -0.08686856,
|
|
0.11868599, 0.07663395, -0.07323171, 0.03463402,
|
|
-0.050708205, -0.04458982, -0.11590894, 0.021273347,
|
|
0.1251325, -0.15313013, -0.12224372, 0.17228661,
|
|
0.023029093, 0.086124025, 0.006445803, -0.03496501,
|
|
0.028332196, 0.04449512, -0.042436164, -0.026587414,
|
|
-0.006041347, -0.09292539, -0.05678812, 0.03897832,
|
|
0.09465633, 0.008115513, -0.02171956, 0.08304309,
|
|
0.071401566, 0.019622514, 0.032163795, -0.004167056,
|
|
0.02295182, 0.030739572, 0.056506045, 0.004612461,
|
|
0.06524936, 0.059999723, 0.046395954, -0.0045512207,
|
|
-0.1335546, -0.030136576, 0.11584653, -0.014678886,
|
|
0.0020118146, -0.09688814, -0.0790206, 0.039770417,
|
|
-0.0329582, 0.07922767, 0.029322514, 0.026405897,
|
|
0.04207835, -0.07073373, 0.063781224, 0.0859677,
|
|
-0.10925287, -0.07011058, 0.048005477, 0.03438226,
|
|
-0.09606514, -0.006669445, -0.043381985, 0.04240257,
|
|
-0.06955775, -0.06769346, 0.043903265, -0.026784198,
|
|
-0.017840602, 0.024307009, -0.040079936, -0.019946516,
|
|
0.045318738, -0.12233574, 0.026170589, 0.0074471775,
|
|
0.15978073, 0.10185836, 0.10298046, -0.015476589,
|
|
-0.039390966, -0.072174534, 0.0739445, -0.1211869,
|
|
-0.0347889, -0.07943156, 0.014809798, -0.12412325,
|
|
-0.0030663363, 0.039695457, 0.0647603, -0.08291318,
|
|
-0.018529687, -0.004423833, 0.0037507233, 0.084633216,
|
|
-0.01514876, -0.056505352, -0.012800942, -0.06994386,
|
|
0.012962922, -0.031234352, 0.07029052, 0.016418684,
|
|
0.03618972, 0.055686004, -0.08663945, -0.017404709,
|
|
-0.054761406, 0.029065743, 0.052404847, 0.020238016,
|
|
0.0048197987, -0.0214882, 0.07078733, 0.013016777,
|
|
0.06262858, 0.009184685, 0.020785125, -0.043904778,
|
|
-0.0270329, -0.03299152, -0.060088247, -0.015162964,
|
|
-0.001828936, 0.12642565, -0.056757294, 0.013586685,
|
|
0.09232601, -0.035886683, 0.06000002, 0.05229691,
|
|
-0.052580316, -0.082029596, -0.010794592, 0.012947712,
|
|
-0.036429964, -0.085508935, -0.13127148, -0.017744139,
|
|
0.031502828, 0.036232427, -0.031581745, 0.023051167,
|
|
-0.05325106, -0.03421577, 0.028793324, -0.034633752,
|
|
-0.009881397, -0.043551125, -0.018609839, 0.0019097115,
|
|
-0.008799762, 0.056595087, 0.0022273948, 0.055752404};
|
|
|
|
recurrent_to_forget_weights_ = {
|
|
-0.057784554, -0.026057621, -0.068447545, -0.022581743,
|
|
0.14811787, 0.10826372, 0.09471067, 0.03987225,
|
|
-0.0039523416, 0.00030638507, 0.053185795, 0.10572994,
|
|
0.08414449, -0.022036452, -0.00066928595, -0.09203576,
|
|
0.032950465, -0.10985798, -0.023809856, 0.0021431844,
|
|
-0.02196096, -0.00326074, 0.00058621005, -0.074678116,
|
|
-0.06193199, 0.055729095, 0.03736828, 0.020123724,
|
|
0.061878487, -0.04729229, 0.034919553, -0.07585433,
|
|
-0.04421272, -0.044019096, 0.085488975, 0.04058006,
|
|
-0.06890133, -0.030951202, -0.024628663, -0.07672815,
|
|
0.034293607, 0.08556707, -0.05293577, -0.033561368,
|
|
-0.04899627, 0.0241671, 0.015736353, -0.095442444,
|
|
-0.029564252, 0.016493602, -0.035026584, 0.022337519,
|
|
-0.026871363, 0.004780428, 0.0077918363, -0.03601621,
|
|
0.016435321, -0.03263031, -0.09543275, -0.047392778,
|
|
0.013454138, 0.028934088, 0.01685226, -0.086110644,
|
|
-0.046250615, -0.01847454, 0.047608484, 0.07339695,
|
|
0.034546845, -0.04881143, 0.009128804, -0.08802852,
|
|
0.03761666, 0.008096139, -0.014454086, 0.014361001,
|
|
-0.023502491, -0.0011840804, -0.07607001, 0.001856849,
|
|
-0.06509276, -0.006021153, -0.08570962, -0.1451793,
|
|
0.060212336, 0.055259194, 0.06974018, 0.049454916,
|
|
-0.027794661, -0.08077226, -0.016179763, 0.1169753,
|
|
0.17213494, -0.0056326236, -0.053934924, -0.0124349,
|
|
-0.11520337, 0.05409887, 0.088759385, 0.0019655675,
|
|
0.0042065294, 0.03881498, 0.019844765, 0.041858196,
|
|
-0.05695512, 0.047233116, 0.038937137, -0.06542224,
|
|
0.014429736, -0.09719407, 0.13908425, -0.05379757,
|
|
0.012321099, 0.082840554, -0.029899208, 0.044217527,
|
|
0.059855383, 0.07711018, -0.045319796, 0.0948846,
|
|
-0.011724666, -0.0033288454, -0.033542685, -0.04764985,
|
|
-0.13873616, 0.040668588, 0.034832682, -0.015319203,
|
|
-0.018715994, 0.046002675, 0.0599172, -0.043107376,
|
|
0.0294216, -0.002314414, -0.022424703, 0.0030315618,
|
|
0.0014641669, 0.0029166266, -0.11878115, 0.013738511,
|
|
0.12375372, -0.0006038222, 0.029104086, 0.087442465,
|
|
0.052958444, 0.07558703, 0.04817258, 0.044462286,
|
|
-0.015213451, -0.08783778, -0.0561384, -0.003008196,
|
|
0.047060397, -0.002058388, 0.03429439, -0.018839769,
|
|
0.024734668, 0.024614193, -0.042046934, 0.09597743,
|
|
-0.0043254104, 0.04320769, 0.0064070094, -0.0019131786,
|
|
-0.02558259, -0.022822596, -0.023273505, -0.02464396,
|
|
-0.10991725, -0.006240552, 0.0074488563, 0.024044557,
|
|
0.04383914, -0.046476185, 0.028658995, 0.060410924,
|
|
0.050786525, 0.009452605, -0.0073054377, -0.024810238,
|
|
0.0052906186, 0.0066939713, -0.0020913032, 0.014515517,
|
|
0.015898481, 0.021362653, -0.030262267, 0.016587038,
|
|
-0.011442813, 0.041154444, -0.007631438, -0.03423484,
|
|
-0.010977775, 0.036152758, 0.0066366293, 0.11915515,
|
|
0.02318443, -0.041350313, 0.021485701, -0.10906167,
|
|
-0.028218046, -0.00954771, 0.020531068, -0.11995105,
|
|
-0.03672871, 0.024019798, 0.014255957, -0.05221243,
|
|
-0.00661567, -0.04630967, 0.033188973, 0.10107534,
|
|
-0.014027541, 0.030796422, -0.10270911, -0.035999842,
|
|
0.15443139, 0.07684145, 0.036571592, -0.035900835,
|
|
-0.0034699554, 0.06209149, 0.015920248, -0.031122351,
|
|
-0.03858649, 0.01849943, 0.13872518, 0.01503974,
|
|
0.069941424, -0.06948533, -0.0088794185, 0.061282158,
|
|
-0.047401894, 0.03100163, -0.041533746, -0.10430945,
|
|
0.044574402, -0.01425562, -0.024290353, 0.034563623,
|
|
0.05866852, 0.023947537, -0.09445152, 0.035450947,
|
|
0.02247216, -0.0042998926, 0.061146557, -0.10250651,
|
|
0.020881841, -0.06747029, 0.10062043, -0.0023941975,
|
|
0.03532124, -0.016341697, 0.09685456, -0.016764693,
|
|
0.051808182, 0.05875331, -0.04536488, 0.001626336,
|
|
-0.028892258, -0.01048663, -0.009793449, -0.017093895,
|
|
0.010987891, 0.02357273, -0.00010856845, 0.0099760275,
|
|
-0.001845119, -0.03551521, 0.0018358806, 0.05763657,
|
|
-0.01769146, 0.040995963, 0.02235177, -0.060430344,
|
|
0.11475477, -0.023854522, 0.10071741, 0.0686208,
|
|
-0.014250481, 0.034261297, 0.047418304, 0.08562733,
|
|
-0.030519066, 0.0060542435, 0.014653856, -0.038836084,
|
|
0.04096551, 0.032249358, -0.08355519, -0.026823482,
|
|
0.056386515, -0.010401743, -0.028396193, 0.08507674,
|
|
0.014410365, 0.020995233, 0.17040324, 0.11511526,
|
|
0.02459721, 0.0066619175, 0.025853224, -0.023133837,
|
|
-0.081302024, 0.017264642, -0.009585969, 0.09491168,
|
|
-0.051313367, 0.054532815, -0.014298593, 0.10657464,
|
|
0.007076659, 0.10964551, 0.0409152, 0.008275321,
|
|
-0.07283536, 0.07937492, 0.04192024, -0.1075027};
|
|
|
|
recurrent_to_output_weights_ = {
|
|
0.025825322, -0.05813119, 0.09495884, -0.045984812,
|
|
-0.01255415, -0.0026479573, -0.08196161, -0.054914974,
|
|
-0.0046604523, -0.029587349, -0.044576716, -0.07480124,
|
|
-0.082868785, 0.023254942, 0.027502948, -0.0039728214,
|
|
-0.08683098, -0.08116779, -0.014675607, -0.037924774,
|
|
-0.023314456, -0.007401714, -0.09255757, 0.029460307,
|
|
-0.08829125, -0.005139627, -0.08989442, -0.0555066,
|
|
0.13596267, -0.025062224, -0.048351806, -0.03850004,
|
|
0.07266485, -0.022414139, 0.05940088, 0.075114764,
|
|
0.09597592, -0.010211725, -0.0049794707, -0.011523867,
|
|
-0.025980417, 0.072999895, 0.11091378, -0.081685916,
|
|
0.014416728, 0.043229222, 0.034178585, -0.07530371,
|
|
0.035837382, -0.085607, -0.007721233, -0.03287832,
|
|
-0.043848954, -0.06404588, -0.06632928, -0.073643476,
|
|
0.008214239, -0.045984086, 0.039764922, 0.03474462,
|
|
0.060612556, -0.080590084, 0.049127717, 0.04151091,
|
|
-0.030063879, 0.008801774, -0.023021035, -0.019558564,
|
|
0.05158114, -0.010947698, -0.011825728, 0.0075720972,
|
|
0.0699727, -0.0039981045, 0.069350146, 0.08799282,
|
|
0.016156472, 0.035502106, 0.11695009, 0.006217345,
|
|
0.13392477, -0.037875112, 0.025745004, 0.08940699,
|
|
-0.00924166, 0.0046702605, -0.036598757, -0.08811812,
|
|
0.10522024, -0.032441203, 0.008176899, -0.04454919,
|
|
0.07058152, 0.0067963637, 0.039206743, 0.03259838,
|
|
0.03725492, -0.09515802, 0.013326398, -0.052055415,
|
|
-0.025676316, 0.03198509, -0.015951829, -0.058556724,
|
|
0.036879618, 0.043357447, 0.028362012, -0.05908629,
|
|
0.0059240665, -0.04995891, -0.019187413, 0.0276265,
|
|
-0.01628143, 0.0025863599, 0.08800015, 0.035250366,
|
|
-0.022165963, -0.07328642, -0.009415526, -0.07455109,
|
|
0.11690406, 0.0363299, 0.07411125, 0.042103454,
|
|
-0.009660886, 0.019076364, 0.018299393, -0.046004917,
|
|
0.08891175, 0.0431396, -0.026327137, -0.051502608,
|
|
0.08979574, -0.051670972, 0.04940282, -0.07491107,
|
|
-0.021240504, 0.022596184, -0.034280192, 0.060163025,
|
|
-0.058211457, -0.051837247, -0.01349775, -0.04639988,
|
|
-0.035936575, -0.011681591, 0.064818054, 0.0073146066,
|
|
-0.021745546, -0.043124277, -0.06471268, -0.07053354,
|
|
-0.029321948, -0.05330136, 0.016933719, -0.053782392,
|
|
0.13747959, -0.1361751, -0.11569455, 0.0033329215,
|
|
0.05693899, -0.053219706, 0.063698, 0.07977434,
|
|
-0.07924483, 0.06936997, 0.0034815092, -0.007305279,
|
|
-0.037325785, -0.07251102, -0.033633437, -0.08677009,
|
|
0.091591336, -0.14165086, 0.021752775, 0.019683983,
|
|
0.0011612234, -0.058154266, 0.049996935, 0.0288841,
|
|
-0.0024567875, -0.14345716, 0.010955264, -0.10234828,
|
|
0.1183656, -0.0010731248, -0.023590032, -0.072285876,
|
|
-0.0724771, -0.026382286, -0.0014920527, 0.042667855,
|
|
0.0018776858, 0.02986552, 0.009814309, 0.0733756,
|
|
0.12289186, 0.018043943, -0.0458958, 0.049412545,
|
|
0.033632483, 0.05495232, 0.036686596, -0.013781798,
|
|
-0.010036754, 0.02576849, -0.08307328, 0.010112348,
|
|
0.042521734, -0.05869831, -0.071689695, 0.03876447,
|
|
-0.13275425, -0.0352966, -0.023077697, 0.10285965,
|
|
0.084736146, 0.15568255, -0.00040734606, 0.027835453,
|
|
-0.10292561, -0.032401145, 0.10053256, -0.026142767,
|
|
-0.08271222, -0.0030240538, -0.016368777, 0.1070414,
|
|
0.042672627, 0.013456989, -0.0437609, -0.022309763,
|
|
0.11576483, 0.04108048, 0.061026827, -0.0190714,
|
|
-0.0869359, 0.037901703, 0.0610107, 0.07202949,
|
|
0.01675338, 0.086139716, -0.08795751, -0.014898893,
|
|
-0.023771819, -0.01965048, 0.007955471, -0.043740474,
|
|
0.03346837, -0.10549954, 0.090567775, 0.042013682,
|
|
-0.03176985, 0.12569028, -0.02421228, -0.029526481,
|
|
0.023851605, 0.031539805, 0.05292009, -0.02344001,
|
|
-0.07811758, -0.08834428, 0.10094801, 0.16594367,
|
|
-0.06861939, -0.021256343, -0.041093912, -0.06669611,
|
|
0.035498552, 0.021757556, -0.09302526, -0.015403468,
|
|
-0.06614931, -0.051798206, -0.013874718, 0.03630673,
|
|
0.010412845, -0.08077351, 0.046185967, 0.0035662893,
|
|
0.03541868, -0.094149634, -0.034814864, 0.003128424,
|
|
-0.020674974, -0.03944324, -0.008110165, -0.11113267,
|
|
0.08484226, 0.043586485, 0.040582247, 0.0968012,
|
|
-0.065249965, -0.028036479, 0.0050708856, 0.0017462453,
|
|
0.0326779, 0.041296225, 0.09164146, -0.047743853,
|
|
-0.015952192, -0.034451712, 0.084197424, -0.05347844,
|
|
-0.11768019, 0.085926116, -0.08251791, -0.045081906,
|
|
0.0948852, 0.068401024, 0.024856757, 0.06978981,
|
|
-0.057309967, -0.012775832, -0.0032452994, 0.01977615,
|
|
-0.041040014, -0.024264973, 0.063464895, 0.05431621,
|
|
};
|
|
|
|
cell_to_input_weights_ = {
|
|
0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458,
|
|
-0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174,
|
|
-0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047,
|
|
0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175};
|
|
|
|
cell_to_forget_weights_ = {
|
|
-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276,
|
|
-0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
|
|
-0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774,
|
|
0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355};
|
|
|
|
cell_to_output_weights_ = {
|
|
0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764,
|
|
-0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544,
|
|
-0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817,
|
|
0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733};
|
|
|
|
projection_weights_ = {
|
|
-0.009802181, 0.09401916, 0.0717386, -0.13895074,
|
|
0.09641832, 0.060420845, 0.08539281, 0.054285463,
|
|
0.061395317, 0.034448683, -0.042991187, 0.019801661,
|
|
-0.16840284, -0.015726732, -0.23041931, -0.024478018,
|
|
-0.10959692, -0.013875541, 0.18600968, -0.061274476,
|
|
0.0138165, -0.08160894, -0.07661644, 0.032372914,
|
|
0.16169067, 0.22465782, -0.03993472, -0.004017731,
|
|
0.08633481, -0.28869787, 0.08682067, 0.17240396,
|
|
0.014975425, 0.056431185, 0.031037588, 0.16702051,
|
|
0.0077946745, 0.15140012, 0.29405436, 0.120285,
|
|
-0.188994, -0.027265169, 0.043389652, -0.022061434,
|
|
0.014777949, -0.20203483, 0.094781205, 0.19100232,
|
|
0.13987629, -0.036132768, -0.06426278, -0.05108664,
|
|
0.13221376, 0.009441198, -0.16715929, 0.15859416,
|
|
-0.040437475, 0.050779544, -0.022187516, 0.012166504,
|
|
0.027685808, -0.07675938, -0.0055694645, -0.09444123,
|
|
0.0046453946, 0.050794356, 0.10770313, -0.20790008,
|
|
-0.07149004, -0.11425117, 0.008225835, -0.035802525,
|
|
0.14374903, 0.15262283, 0.048710253, 0.1847461,
|
|
-0.007487823, 0.11000021, -0.09542012, 0.22619456,
|
|
-0.029149994, 0.08527916, 0.009043713, 0.0042746216,
|
|
0.016261552, 0.022461696, 0.12689082, -0.043589946,
|
|
-0.12035478, -0.08361797, -0.050666027, -0.1248618,
|
|
-0.1275799, -0.071875185, 0.07377272, 0.09944291,
|
|
-0.18897448, -0.1593054, -0.06526116, -0.040107165,
|
|
-0.004618631, -0.067624845, -0.007576253, 0.10727444,
|
|
0.041546922, -0.20424393, 0.06907816, 0.050412357,
|
|
0.00724631, 0.039827548, 0.12449835, 0.10747581,
|
|
0.13708383, 0.09134148, -0.12617786, -0.06428341,
|
|
0.09956831, 0.1208086, -0.14676677, -0.0727722,
|
|
0.1126304, 0.010139365, 0.015571211, -0.038128063,
|
|
0.022913318, -0.042050496, 0.16842307, -0.060597885,
|
|
0.10531834, -0.06411776, -0.07451711, -0.03410368,
|
|
-0.13393489, 0.06534304, 0.003620307, 0.04490757,
|
|
0.05970546, 0.05197996, 0.02839995, 0.10434969,
|
|
-0.013699693, -0.028353551, -0.07260381, 0.047201227,
|
|
-0.024575593, -0.036445823, 0.07155557, 0.009672501,
|
|
-0.02328883, 0.009533515, -0.03606021, -0.07421458,
|
|
-0.028082801, -0.2678904, -0.13221288, 0.18419984,
|
|
-0.13012612, -0.014588381, -0.035059117, -0.04824723,
|
|
0.07830115, -0.056184657, 0.03277091, 0.025466874,
|
|
0.14494097, -0.12522776, -0.098633975, -0.10766018,
|
|
-0.08317623, 0.08594209, 0.07749552, 0.039474737,
|
|
0.1776665, -0.07409566, -0.0477268, 0.29323658,
|
|
0.10801441, 0.1154011, 0.013952499, 0.10739139,
|
|
0.10708251, -0.051456142, 0.0074137426, -0.10430189,
|
|
0.10034707, 0.045594677, 0.0635285, -0.0715442,
|
|
-0.089667566, -0.10811871, 0.00026344223, 0.08298446,
|
|
-0.009525053, 0.006585689, -0.24567553, -0.09450807,
|
|
0.09648481, 0.026996298, -0.06419476, -0.04752702,
|
|
-0.11063944, -0.23441927, -0.17608605, -0.052156363,
|
|
0.067035615, 0.19271925, -0.0032889997, -0.043264326,
|
|
0.09663576, -0.057112187, -0.10100678, 0.0628376,
|
|
0.04447668, 0.017961001, -0.10094388, -0.10190601,
|
|
0.18335468, 0.10494553, -0.052095775, -0.0026118709,
|
|
0.10539724, -0.04383912, -0.042349473, 0.08438151,
|
|
-0.1947263, 0.02251204, 0.11216432, -0.10307853,
|
|
0.17351969, -0.039091777, 0.08066188, -0.00561982,
|
|
0.12633002, 0.11335965, -0.0088127935, -0.019777594,
|
|
0.06864014, -0.059751723, 0.016233567, -0.06894641,
|
|
-0.28651384, -0.004228674, 0.019708522, -0.16305895,
|
|
-0.07468996, -0.0855457, 0.099339016, -0.07580735,
|
|
-0.13775392, 0.08434318, 0.08330512, -0.12131499,
|
|
0.031935584, 0.09180414, -0.08876437, -0.08049874,
|
|
0.008753825, 0.03498998, 0.030215185, 0.03907079,
|
|
0.089751154, 0.029194152, -0.03337423, -0.019092513,
|
|
0.04331237, 0.04299654, -0.036394123, -0.12915532,
|
|
0.09793732, 0.07512415, -0.11319543, -0.032502122,
|
|
0.15661901, 0.07671967, -0.005491124, -0.19379048,
|
|
-0.218606, 0.21448623, 0.017840758, 0.1416943,
|
|
-0.07051762, 0.19488361, 0.02664691, -0.18104725,
|
|
-0.09334311, 0.15026465, -0.15493552, -0.057762887,
|
|
-0.11604192, -0.262013, -0.01391798, 0.012185008,
|
|
0.11156489, -0.07483202, 0.06693364, -0.26151478,
|
|
0.046425626, 0.036540434, -0.16435726, 0.17338543,
|
|
-0.21401681, -0.11385144, -0.08283257, -0.069031075,
|
|
0.030635102, 0.010969227, 0.11109743, 0.010919218,
|
|
0.027526086, 0.13519906, 0.01891392, -0.046839405,
|
|
-0.040167913, 0.017953383, -0.09700955, 0.0061885654,
|
|
-0.07000971, 0.026893595, -0.038844477, 0.14543656};
|
|
|
|
projection_bias_ = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8,
|
|
0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6};
|
|
|
|
lstm_input_ = {
|
|
{// Batch0: 4 (input_sequence_size) * 5 (n_input)
|
|
0.787926, 0.151646, 0.071352, 0.118426, 0.458058, // step 0
|
|
0.596268, 0.998386, 0.568695, 0.864524, 0.571277, // step 1
|
|
0.073204, 0.296072, 0.743333, 0.069199, 0.045348, // step 2
|
|
0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, // step 3
|
|
|
|
{// Batch1: 4 (input_sequence_size) * 5 (n_input)
|
|
0.295743, 0.544053, 0.690064, 0.858138, 0.497181, // step 0
|
|
0.642421, 0.524260, 0.134799, 0.003639, 0.162482, // step 1
|
|
0.640394, 0.930399, 0.050782, 0.432485, 0.988078, // step 2
|
|
0.082922, 0.563329, 0.865614, 0.333232, 0.259916} // step 3
|
|
};
|
|
|
|
lstm_golden_output_ = {
|
|
{// Batch0: 4 (input_sequence_size) * 16 (n_output)
|
|
0.0960319489, 0.229351997, 0.297207743, 0.415997744, 0.491644233,
|
|
0.578822136, 0.728351235, 0.788540304, 0.909073055, 0.975599587,
|
|
1.08478093, 1.17409372, 1.30914319, 1.4041512, 1.51714694,
|
|
1.61342025, 0.0634541437, 0.190279216, 0.317923307, 0.415168911,
|
|
0.458113253, 0.609743774, 0.731511116, 0.795806408, 0.876155913,
|
|
0.960330188, 1.12396312, 1.22149014, 1.33917773, 1.43213499,
|
|
1.54139447, 1.65451813, 0.0485293195, 0.160991609, 0.337073475,
|
|
0.428976893, 0.459505379, 0.617044866, 0.743735075, 0.790821671,
|
|
0.85271728, 0.946818829, 1.12779701, 1.23345077, 1.35309088,
|
|
1.44595909, 1.56173062, 1.67839324, 0.0445971154, 0.156434938,
|
|
0.341761589, 0.425259203, 0.449760497, 0.633765697, 0.745093822,
|
|
0.791106999, 0.84820503, 0.952787101, 1.13438797, 1.24063754,
|
|
1.34668994, 1.44879568, 1.57038593, 1.67956686},
|
|
{// Batch1: 4 (input_sequence_size) * 16 (n_output)
|
|
0.0861309841, 0.228726774, 0.296653062, 0.40733397, 0.47120741,
|
|
0.581307411, 0.719366193, 0.788456261, 0.904226124, 0.965476751,
|
|
1.10223258, 1.19042683, 1.32106233, 1.41333091, 1.51509535,
|
|
1.62168002, 0.0652779415, 0.18218407, 0.324066937, 0.42611438,
|
|
0.47292757, 0.602282405, 0.739310443, 0.791508496, 0.870626807,
|
|
0.955534995, 1.10976851, 1.21598971, 1.34197009, 1.43256509,
|
|
1.54804492, 1.65581059, 0.0492607877, 0.169714347, 0.332315415,
|
|
0.419173867, 0.44699502, 0.630063772, 0.737177074, 0.792844594,
|
|
0.858417571, 0.956391335, 1.13453305, 1.23976779, 1.34693861,
|
|
1.4410423, 1.55988359, 1.67204297, 0.0390465111, 0.15099439,
|
|
0.3439475, 0.424439192, 0.444207728, 0.632501483, 0.742233515,
|
|
0.791400731, 0.845713973, 0.944575012, 1.14116096, 1.24791968,
|
|
1.35954499, 1.45086145, 1.56633317, 1.68943977}};
|
|
}
|
|
};
|
|
|
|
TEST_F(NoCifgPeepholeProjectionAndBiasClippingUnidirectionalLstmTest,
       LstmBlackBoxTest) {
  // Full LSTM: explicit input gate (no CIFG), peephole connections, and a
  // projection layer with bias. Both clip values are passed as 0.0
  // (NOTE(review): presumably "no clipping" despite the fixture name — confirm).
  constexpr int kBatches = 2;
  constexpr int kInputs = 5;
  constexpr int kCells = 20;
  constexpr int kOutputs = 16;
  constexpr int kSteps = 4;

  // Shape of every op input, one entry per tensor.
  const std::vector<std::vector<int>> tensor_shapes = {
      {kSteps, kBatches, kInputs},  // input tensor

      {kCells, kInputs},  // input_to_input_weight tensor
      {kCells, kInputs},  // input_to_forget_weight tensor
      {kCells, kInputs},  // input_to_cell_weight tensor
      {kCells, kInputs},  // input_to_output_weight tensor

      {kCells, kOutputs},  // recurrent_to_input_weight tensor
      {kCells, kOutputs},  // recurrent_to_forget_weight tensor
      {kCells, kOutputs},  // recurrent_to_cell_weight tensor
      {kCells, kOutputs},  // recurrent_to_output_weight tensor

      {kCells},  // cell_to_input_weight tensor
      {kCells},  // cell_to_forget_weight tensor
      {kCells},  // cell_to_output_weight tensor

      {kCells},  // input_gate_bias tensor
      {kCells},  // forget_gate_bias tensor
      {kCells},  // cell_gate_bias tensor
      {kCells},  // output_gate_bias tensor

      {kOutputs, kCells},  // projection_weight tensor
      {kOutputs},          // projection_bias tensor

      {kBatches, kOutputs},  // output_state tensor
      {kBatches, kCells},    // cell_state tensor
  };

  UnidirectionalLSTMOpModel model(
      kBatches, kInputs, kCells, kOutputs, kSteps,
      /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true,
      /*use_projection_weights=*/true, /*use_projection_bias=*/true,
      /*cell_clip=*/0.0, /*proj_clip=*/0.0, tensor_shapes);

  // Input gate parameters.
  model.SetInputToInputWeights(input_to_input_weights_);
  model.SetRecurrentToInputWeights(recurrent_to_input_weights_);
  model.SetCellToInputWeights(cell_to_input_weights_);
  model.SetInputGateBias(input_gate_bias_);

  // Forget gate parameters.
  model.SetInputToForgetWeights(input_to_forget_weights_);
  model.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  model.SetCellToForgetWeights(cell_to_forget_weights_);
  model.SetForgetGateBias(forget_gate_bias_);

  // Cell gate parameters.
  model.SetInputToCellWeights(input_to_cell_weights_);
  model.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  model.SetCellBias(cell_gate_bias_);

  // Output gate parameters.
  model.SetInputToOutputWeights(input_to_output_weights_);
  model.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
  model.SetCellToOutputWeights(cell_to_output_weights_);
  model.SetOutputGateBias(output_gate_bias_);

  // Projection layer.
  model.SetProjectionWeights(projection_weights_);
  model.SetProjectionBias(projection_bias_);

  VerifyGoldens(lstm_input_, lstm_golden_output_, &model);
}
|
|
|
|
class LayerNormUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
|
|
public:
|
|
LayerNormUnidirectionalLSTMOpModel(
|
|
int n_batch, int n_input, int n_cell, int n_output, int sequence_length,
|
|
bool time_major, bool use_cifg, bool use_peephole,
|
|
bool use_projection_weights, bool use_projection_bias, float cell_clip,
|
|
float proj_clip, const std::vector<std::vector<int>>& input_shapes,
|
|
const TensorType& weights_type = TensorType_FLOAT32)
|
|
: UnidirectionalLSTMOpModel(
|
|
n_batch, n_input, n_cell, n_output, sequence_length, time_major,
|
|
use_cifg, use_peephole, use_projection_weights, use_projection_bias,
|
|
cell_clip, proj_clip, input_shapes, TensorType_FLOAT32, true) {}
|
|
};
|
|
|
|
class BaseLayerNormUnidirectionalLstmTest : public ::testing::Test {
|
|
protected:
|
|
// Weights of the LSTM model. Some are optional.
|
|
std::vector<float> input_to_input_weights_;
|
|
std::vector<float> input_to_cell_weights_;
|
|
std::vector<float> input_to_forget_weights_;
|
|
std::vector<float> input_to_output_weights_;
|
|
std::vector<float> input_gate_bias_;
|
|
std::vector<float> cell_gate_bias_;
|
|
std::vector<float> forget_gate_bias_;
|
|
std::vector<float> output_gate_bias_;
|
|
std::vector<float> recurrent_to_input_weights_;
|
|
std::vector<float> recurrent_to_cell_weights_;
|
|
std::vector<float> recurrent_to_forget_weights_;
|
|
std::vector<float> recurrent_to_output_weights_;
|
|
std::vector<float> cell_to_input_weights_;
|
|
std::vector<float> cell_to_forget_weights_;
|
|
std::vector<float> cell_to_output_weights_;
|
|
std::vector<float> projection_weights_;
|
|
std::vector<float> projection_bias_;
|
|
std::vector<float> input_layer_norm_coefficients_;
|
|
std::vector<float> forget_layer_norm_coefficients_;
|
|
std::vector<float> cell_layer_norm_coefficients_;
|
|
std::vector<float> output_layer_norm_coefficients_;
|
|
|
|
// LSTM input is stored as num_batch x num_inputs vector.
|
|
std::vector<std::vector<float>> lstm_input_;
|
|
// LSTM output is stored as num_batch x num_outputs vector.
|
|
std::vector<std::vector<float>> lstm_golden_output_;
|
|
|
|
// Compares output up to tolerance to the result of the lstm given the input.
|
|
void VerifyGoldens(const std::vector<std::vector<float>>& input,
|
|
const std::vector<std::vector<float>>& output,
|
|
UnidirectionalLSTMOpModel* lstm, float tolerance = 1e-5) {
|
|
const int num_batches = input.size();
|
|
EXPECT_GT(num_batches, 0);
|
|
const int num_inputs = lstm->num_inputs();
|
|
EXPECT_GT(num_inputs, 0);
|
|
const int input_sequence_size = input[0].size() / num_inputs;
|
|
EXPECT_GT(input_sequence_size, 0);
|
|
// Feed the whole sequence as input.
|
|
for (int i = 0; i < input_sequence_size; ++i) {
|
|
for (int b = 0; b < num_batches; ++b) {
|
|
const float* batch_start = input[b].data() + i * num_inputs;
|
|
const float* batch_end = batch_start + num_inputs;
|
|
|
|
lstm->SetInput(((i * num_batches) + b) * num_inputs, batch_start,
|
|
batch_end);
|
|
}
|
|
}
|
|
|
|
lstm->Invoke();
|
|
|
|
const int num_outputs = lstm->num_outputs();
|
|
EXPECT_GT(num_outputs, 0);
|
|
std::vector<float> expected;
|
|
|
|
for (int i = 0; i < input_sequence_size; ++i) {
|
|
for (int b = 0; b < num_batches; ++b) {
|
|
const float* golden_start_batch = output[b].data() + i * num_outputs;
|
|
const float* golden_end_batch = golden_start_batch + num_outputs;
|
|
|
|
expected.insert(expected.end(), golden_start_batch, golden_end_batch);
|
|
}
|
|
}
|
|
EXPECT_THAT(lstm->GetOutput(),
|
|
ElementsAreArray(ArrayFloatNear(expected, tolerance)));
|
|
}
|
|
};
|
|
|
|
class CifgPeepholeNoProjectionNoClippingLayerNormUnidirectionalLstmTest
|
|
: public BaseLayerNormUnidirectionalLstmTest {
|
|
void SetUp() override {
|
|
input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726,
|
|
0.05100781, 0.04717243, 0.48944736,
|
|
-0.38535351, -0.17212132};
|
|
|
|
input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988,
|
|
-0.3633365, -0.22755712, 0.28253698,
|
|
0.24407166, 0.33826375};
|
|
|
|
input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593,
|
|
-0.09426838, -0.44257352, 0.54939759,
|
|
0.01533556, 0.42751634};
|
|
cell_gate_bias_ = {0., 0., 0., 0.};
|
|
forget_gate_bias_ = {1., 1., 1., 1.};
|
|
output_gate_bias_ = {0., 0., 0., 0.};
|
|
|
|
recurrent_to_cell_weights_ = {
|
|
0.54066205, -0.32668582, -0.43562764, -0.56094903,
|
|
0.42957711, 0.01841056, -0.32764608, -0.33027974,
|
|
-0.10826075, 0.20675004, 0.19069612, -0.03026325,
|
|
-0.54532051, 0.33003211, 0.44901288, 0.21193194};
|
|
|
|
recurrent_to_forget_weights_ = {
|
|
-0.13832897, -0.0515101, -0.2359007, -0.16661474,
|
|
-0.14340827, 0.36986142, 0.23414481, 0.55899,
|
|
0.10798943, -0.41174671, 0.17751795, -0.34484994,
|
|
-0.35874045, -0.11352962, 0.27268326, 0.54058349};
|
|
|
|
recurrent_to_output_weights_ = {
|
|
0.41613156, 0.42610586, -0.16495961, -0.5663873,
|
|
0.30579174, -0.05115908, -0.33941799, 0.23364776,
|
|
0.11178309, 0.09481031, -0.26424935, 0.46261835,
|
|
0.50248802, 0.26114327, -0.43736315, 0.33149987};
|
|
|
|
cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408,
|
|
0.31544167};
|
|
cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703,
|
|
-0.77109635};
|
|
|
|
input_layer_norm_coefficients_ = {0.1, 0.2, 0.3, 0.5};
|
|
forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3};
|
|
cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8};
|
|
output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5};
|
|
|
|
lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
|
|
lstm_golden_output_ = {{-0.102089, 0.00653987, 0.0515139, -0.0630045,
|
|
-0.173317, 0.0109206, 0.0903292, -0.109497,
|
|
-0.23827, 0.0119514, 0.119525, -0.12748}};
|
|
}
|
|
};
|
|
|
|
TEST_F(CifgPeepholeNoProjectionNoClippingLayerNormUnidirectionalLstmTest,
       LayerNormLstmBlackBoxTest) {
  // CIFG (no input gate) + peephole LSTM with layer-norm coefficients on the
  // forget, cell and output gates; no projection, clip values passed as 0.0.
  constexpr int kBatches = 1;
  constexpr int kInputs = 2;
  // Without a projection layer the output width equals the cell count.
  constexpr int kCells = 4;
  constexpr int kOutputs = 4;
  constexpr int kSteps = 3;

  // Shape of every op input; tensors left out by CIFG / no-projection get an
  // all-zero shape.
  const std::vector<std::vector<int>> tensor_shapes = {
      {kSteps, kBatches, kInputs},  // input tensor

      {0, 0},             // input_to_input_weight tensor
      {kCells, kInputs},  // input_to_forget_weight tensor
      {kCells, kInputs},  // input_to_cell_weight tensor
      {kCells, kInputs},  // input_to_output_weight tensor

      {0, 0},              // recurrent_to_input_weight tensor
      {kCells, kOutputs},  // recurrent_to_forget_weight tensor
      {kCells, kOutputs},  // recurrent_to_cell_weight tensor
      {kCells, kOutputs},  // recurrent_to_output_weight tensor

      {0},       // cell_to_input_weight tensor
      {kCells},  // cell_to_forget_weight tensor
      {kCells},  // cell_to_output_weight tensor

      {0},       // input_gate_bias tensor
      {kCells},  // forget_gate_bias tensor
      {kCells},  // cell_gate_bias tensor
      {kCells},  // output_gate_bias tensor

      {0, 0},  // projection_weight tensor
      {0},     // projection_bias tensor

      {kBatches, kOutputs},  // output_state tensor
      {kBatches, kCells},    // cell_state tensor

      {0},       // input_layer_norm_coefficient tensor
      {kCells},  // forget_layer_norm_coefficient tensor
      {kCells},  // cell_layer_norm_coefficient tensor
      {kCells},  // output_layer_norm_coefficient tensor
  };

  LayerNormUnidirectionalLSTMOpModel model(
      kBatches, kInputs, kCells, kOutputs, kSteps,
      /*time_major=*/true, /*use_cifg=*/true, /*use_peephole=*/true,
      /*use_projection_weights=*/false, /*use_projection_bias=*/false,
      /*cell_clip=*/0.0, /*proj_clip=*/0.0, tensor_shapes);

  // Forget gate parameters.
  model.SetInputToForgetWeights(input_to_forget_weights_);
  model.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  model.SetCellToForgetWeights(cell_to_forget_weights_);
  model.SetForgetGateBias(forget_gate_bias_);
  model.SetForgetLayerNormCoefficients(forget_layer_norm_coefficients_);

  // Cell gate parameters.
  model.SetInputToCellWeights(input_to_cell_weights_);
  model.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  model.SetCellBias(cell_gate_bias_);
  model.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);

  // Output gate parameters.
  model.SetInputToOutputWeights(input_to_output_weights_);
  model.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
  model.SetCellToOutputWeights(cell_to_output_weights_);
  model.SetOutputGateBias(output_gate_bias_);
  model.SetOutputLayerNormCoefficients(output_layer_norm_coefficients_);

  VerifyGoldens(lstm_input_, lstm_golden_output_, &model);
}
|
|
|
|
TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
       NonLayerNormLstmBlackBoxTest) {
  // Builds the layer-norm model variant, but gives all four layer-norm
  // coefficient tensors an empty {0} shape and never sets them, checking the
  // result against this (non-layer-norm) fixture's goldens.
  constexpr int kBatches = 1;
  constexpr int kInputs = 2;
  // Without a projection layer the output width equals the cell count.
  constexpr int kCells = 4;
  constexpr int kOutputs = 4;
  constexpr int kSteps = 3;

  // Shape of every op input; tensors left out by CIFG / no-projection / no
  // layer norm get an all-zero shape.
  const std::vector<std::vector<int>> tensor_shapes = {
      {kSteps, kBatches, kInputs},  // input tensor

      {0, 0},             // input_to_input_weight tensor
      {kCells, kInputs},  // input_to_forget_weight tensor
      {kCells, kInputs},  // input_to_cell_weight tensor
      {kCells, kInputs},  // input_to_output_weight tensor

      {0, 0},              // recurrent_to_input_weight tensor
      {kCells, kOutputs},  // recurrent_to_forget_weight tensor
      {kCells, kOutputs},  // recurrent_to_cell_weight tensor
      {kCells, kOutputs},  // recurrent_to_output_weight tensor

      {0},       // cell_to_input_weight tensor
      {kCells},  // cell_to_forget_weight tensor
      {kCells},  // cell_to_output_weight tensor

      {0},       // input_gate_bias tensor
      {kCells},  // forget_gate_bias tensor
      {kCells},  // cell_gate_bias tensor
      {kCells},  // output_gate_bias tensor

      {0, 0},  // projection_weight tensor
      {0},     // projection_bias tensor

      {kBatches, kOutputs},  // output_state tensor
      {kBatches, kCells},    // cell_state tensor

      {0},  // input_layer_norm_coefficient tensor
      {0},  // forget_layer_norm_coefficient tensor
      {0},  // cell_layer_norm_coefficient tensor
      {0},  // output_layer_norm_coefficient tensor
  };

  LayerNormUnidirectionalLSTMOpModel model(
      kBatches, kInputs, kCells, kOutputs, kSteps,
      /*time_major=*/true, /*use_cifg=*/true, /*use_peephole=*/true,
      /*use_projection_weights=*/false, /*use_projection_bias=*/false,
      /*cell_clip=*/0.0, /*proj_clip=*/0.0, tensor_shapes);

  // Forget gate parameters.
  model.SetInputToForgetWeights(input_to_forget_weights_);
  model.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  model.SetCellToForgetWeights(cell_to_forget_weights_);
  model.SetForgetGateBias(forget_gate_bias_);

  // Cell gate parameters.
  model.SetInputToCellWeights(input_to_cell_weights_);
  model.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  model.SetCellBias(cell_gate_bias_);

  // Output gate parameters.
  model.SetInputToOutputWeights(input_to_output_weights_);
  model.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
  model.SetCellToOutputWeights(cell_to_output_weights_);
  model.SetOutputGateBias(output_gate_bias_);

  VerifyGoldens(lstm_input_, lstm_golden_output_, &model);
}
|
|
|
|
class UnidirectionalSequenceLSTMIntegerOpModel : public SingleOpModel {
|
|
public:
|
|
UnidirectionalSequenceLSTMIntegerOpModel(
|
|
int n_batch, int n_input, int n_cell, int n_output, int sequence_length,
|
|
bool time_major, bool use_cifg, bool use_peephole,
|
|
bool use_projection_weights, bool use_projection_bias,
|
|
bool use_layer_norm, bool use_8x8_8_implementation,
|
|
const std::vector<std::pair<float, float>>& ranges,
|
|
const std::vector<std::pair<float, int>>& intermediates,
|
|
bool asymmetric_quantize_inputs = false)
|
|
: n_input_(n_input), n_output_(n_output) {
|
|
input_ = AddInput({TensorType_INT8,
|
|
{sequence_length, n_batch, n_input},
|
|
ranges[0].first,
|
|
ranges[0].second});
|
|
|
|
if (use_cifg) {
|
|
input_to_input_weights_ = AddNullInput();
|
|
} else {
|
|
input_to_input_weights_ = AddInput({TensorType_INT8,
|
|
{n_cell, n_input},
|
|
ranges[1].first,
|
|
ranges[1].second});
|
|
}
|
|
input_to_forget_weights_ = AddInput({TensorType_INT8,
|
|
{n_cell, n_input},
|
|
ranges[2].first,
|
|
ranges[2].second});
|
|
input_to_cell_weights_ = AddInput({TensorType_INT8,
|
|
{n_cell, n_input},
|
|
ranges[3].first,
|
|
ranges[3].second});
|
|
input_to_output_weights_ = AddInput({TensorType_INT8,
|
|
{n_cell, n_input},
|
|
ranges[4].first,
|
|
ranges[4].second});
|
|
|
|
if (use_cifg) {
|
|
recurrent_to_input_weights_ = AddNullInput();
|
|
} else {
|
|
recurrent_to_input_weights_ = AddInput({TensorType_INT8,
|
|
{n_cell, n_output},
|
|
ranges[5].first,
|
|
ranges[5].second});
|
|
}
|
|
recurrent_to_forget_weights_ = AddInput({TensorType_INT8,
|
|
{n_cell, n_output},
|
|
ranges[6].first,
|
|
ranges[6].second});
|
|
recurrent_to_cell_weights_ = AddInput({TensorType_INT8,
|
|
{n_cell, n_output},
|
|
ranges[7].first,
|
|
ranges[7].second});
|
|
recurrent_to_output_weights_ = AddInput({TensorType_INT8,
|
|
{n_cell, n_output},
|
|
ranges[8].first,
|
|
ranges[8].second});
|
|
|
|
if (use_peephole) {
|
|
if (use_cifg) {
|
|
cell_to_input_weights_ = AddNullInput();
|
|
} else {
|
|
cell_to_input_weights_ = AddInput(
|
|
{TensorType_INT16, {n_cell}, ranges[9].first, ranges[9].second});
|
|
}
|
|
cell_to_forget_weights_ = AddInput(
|
|
{TensorType_INT16, {n_cell}, ranges[10].first, ranges[10].second});
|
|
cell_to_output_weights_ = AddInput(
|
|
{TensorType_INT16, {n_cell}, ranges[11].first, ranges[11].second});
|
|
} else {
|
|
cell_to_input_weights_ = AddNullInput();
|
|
cell_to_forget_weights_ = AddNullInput();
|
|
cell_to_output_weights_ = AddNullInput();
|
|
}
|
|
|
|
if (use_cifg) {
|
|
input_gate_bias_ = AddNullInput();
|
|
} else {
|
|
input_gate_bias_ = AddInput(
|
|
{TensorType_INT32, {n_cell}, ranges[12].first, ranges[12].second});
|
|
}
|
|
forget_gate_bias_ = AddInput(
|
|
{TensorType_INT32, {n_cell}, ranges[13].first, ranges[13].second});
|
|
cell_gate_bias_ = AddInput(
|
|
{TensorType_INT32, {n_cell}, ranges[14].first, ranges[14].second});
|
|
output_gate_bias_ = AddInput(
|
|
{TensorType_INT32, {n_cell}, ranges[15].first, ranges[15].second});
|
|
|
|
if (use_projection_weights) {
|
|
projection_weights_ = AddInput({TensorType_INT8,
|
|
{n_output, n_cell},
|
|
ranges[16].first,
|
|
ranges[16].second});
|
|
} else {
|
|
projection_weights_ = AddNullInput();
|
|
}
|
|
if (use_projection_bias) {
|
|
CHECK(use_projection_weights);
|
|
projection_bias_ = AddInput(
|
|
{TensorType_INT32, {n_output}, ranges[17].first, ranges[17].second});
|
|
} else {
|
|
projection_bias_ = AddNullInput();
|
|
}
|
|
|
|
// Adding the 2 state tensors.
|
|
AddVariableInput({TensorType_INT16,
|
|
{n_batch, n_output},
|
|
ranges[18].first,
|
|
ranges[18].second});
|
|
AddVariableInput({TensorType_INT16,
|
|
{n_batch, n_cell},
|
|
ranges[19].first,
|
|
ranges[19].second});
|
|
|
|
// Layer norm weights.
|
|
if (use_layer_norm) {
|
|
if (use_cifg) {
|
|
input_layer_norm_coefficients_ = AddNullInput();
|
|
} else {
|
|
input_layer_norm_coefficients_ = AddInput(
|
|
{TensorType_INT16, {n_cell}, ranges[20].first, ranges[20].second});
|
|
}
|
|
forget_layer_norm_coefficients_ = AddInput(
|
|
{TensorType_INT16, {n_cell}, ranges[21].first, ranges[21].second});
|
|
cell_layer_norm_coefficients_ = AddInput(
|
|
{TensorType_INT16, {n_cell}, ranges[22].first, ranges[22].second});
|
|
output_layer_norm_coefficients_ = AddInput(
|
|
{TensorType_INT16, {n_cell}, ranges[23].first, ranges[23].second});
|
|
}
|
|
|
|
// use_8x8_8_implementation is not supported yet.
|
|
CHECK(!use_8x8_8_implementation);
|
|
EXPECT_EQ(intermediates.size(), 5);
|
|
|
|
for (int i = 0; i < intermediates.size(); ++i) {
|
|
AddIntermediate(TensorType_INT16, {intermediates[i].first},
|
|
{intermediates[i].second});
|
|
}
|
|
|
|
output_ = AddOutput({TensorType_INT8,
|
|
{n_batch, n_output},
|
|
ranges[24].first,
|
|
ranges[24].second});
|
|
|
|
// TODO(b/161825581): Add tests where cell_clip and/or proj_clip is not the
|
|
// default 0.
|
|
SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
|
|
BuiltinOptions_UnidirectionalSequenceLSTMOptions,
|
|
CreateUnidirectionalSequenceLSTMOptions(
|
|
builder_, ActivationFunctionType_TANH, /*cell_clip=*/0.0f,
|
|
/*proj_clip=*/0.0f, time_major, asymmetric_quantize_inputs)
|
|
.Union());
|
|
|
|
BuildInterpreter(/*input_shapes=*/{}, /*num_threads=*/-1,
|
|
/*allow_fp32_relax_to_fp16=*/false,
|
|
/*apply_delegate=*/true, /*allocate_and_delegate=*/false);
|
|
}
|
|
|
|
// Allocates tensors and applies the delegate. The constructor builds the
// interpreter with allocate_and_delegate=false, so the tests in this file call
// this explicitly before populating any tensor.
void PerformAllocateAndDelegate() {
  constexpr bool kApplyDelegate = true;
  AllocateAndDelegate(kApplyDelegate);
}
|
|
|
|
// Writes the float values `f` into the input-to-input weight tensor, quantized
// to int8 through QuantizeAndPopulate.
void SetInputToInputWeights(const std::vector<float>& f) {
  const int tensor_index = input_to_input_weights_;
  QuantizeAndPopulate<int8_t>(tensor_index, f);
}
|
|
|
|
// Writes the float values `f` into the input-to-forget weight tensor, quantized
// to int8 through QuantizeAndPopulate.
void SetInputToForgetWeights(const std::vector<float>& f) {
  const int tensor_index = input_to_forget_weights_;
  QuantizeAndPopulate<int8_t>(tensor_index, f);
}
|
|
|
|
// Writes the float values `f` into the input-to-cell weight tensor, quantized
// to int8 through QuantizeAndPopulate.
void SetInputToCellWeights(const std::vector<float>& f) {
  const int tensor_index = input_to_cell_weights_;
  QuantizeAndPopulate<int8_t>(tensor_index, f);
}
|
|
|
|
// Writes the float values `f` into the input-to-output weight tensor, quantized
// to int8 through QuantizeAndPopulate.
void SetInputToOutputWeights(const std::vector<float>& f) {
  const int tensor_index = input_to_output_weights_;
  QuantizeAndPopulate<int8_t>(tensor_index, f);
}
|
|
|
|
// Writes the float values `f` into the recurrent-to-input weight tensor,
// quantized to int8 through QuantizeAndPopulate.
void SetRecurrentToInputWeights(const std::vector<float>& f) {
  const int tensor_index = recurrent_to_input_weights_;
  QuantizeAndPopulate<int8_t>(tensor_index, f);
}
|
|
|
|
// Writes the float values `f` into the recurrent-to-forget weight tensor,
// quantized to int8 through QuantizeAndPopulate.
void SetRecurrentToForgetWeights(const std::vector<float>& f) {
  const int tensor_index = recurrent_to_forget_weights_;
  QuantizeAndPopulate<int8_t>(tensor_index, f);
}
|
|
|
|
// Writes the float values `f` into the recurrent-to-cell weight tensor,
// quantized to int8 through QuantizeAndPopulate.
void SetRecurrentToCellWeights(const std::vector<float>& f) {
  const int tensor_index = recurrent_to_cell_weights_;
  QuantizeAndPopulate<int8_t>(tensor_index, f);
}
|
|
|
|
// Writes the float values `f` into the recurrent-to-output weight tensor
// (declared as INT8 with shape {n_cell, n_output}), quantized to int8 through
// QuantizeAndPopulate.
void SetRecurrentToOutputWeights(const std::vector<float>& f) {
  const int tensor_index = recurrent_to_output_weights_;
  QuantizeAndPopulate<int8_t>(tensor_index, f);
}
|
|
|
|
// Writes the float values `f` into the cell-to-input (peephole) weight tensor,
// quantized to int16 through QuantizeAndPopulate. The tensor is a null input
// unless the model was built with use_peephole=true and use_cifg=false.
void SetCellToInputWeights(const std::vector<float>& f) {
  const int tensor_index = cell_to_input_weights_;
  QuantizeAndPopulate<int16_t>(tensor_index, f);
}
|
|
|
|
// Writes the float values `f` into the cell-to-forget (peephole) weight
// tensor, quantized to int16 through QuantizeAndPopulate. The tensor is a null
// input unless the model was built with use_peephole=true.
void SetCellToForgetWeights(const std::vector<float>& f) {
  const int tensor_index = cell_to_forget_weights_;
  QuantizeAndPopulate<int16_t>(tensor_index, f);
}
|
|
|
|
// Writes the float values `f` into the cell-to-output (peephole) weight
// tensor, quantized to int16 through QuantizeAndPopulate. The tensor is a null
// input unless the model was built with use_peephole=true.
void SetCellToOutputWeights(const std::vector<float>& f) {
  const int tensor_index = cell_to_output_weights_;
  QuantizeAndPopulate<int16_t>(tensor_index, f);
}
|
|
|
|
// Writes the float values `f` into the input-gate layer-norm coefficient
// tensor, quantized to int16 through QuantizeAndPopulate. The constructor only
// adds this tensor when use_layer_norm is set, and as a null input under CIFG.
void SetInputLayerNormCoefficients(const std::vector<float>& f) {
  const int tensor_index = input_layer_norm_coefficients_;
  QuantizeAndPopulate<int16_t>(tensor_index, f);
}
|
|
|
|
// Writes the float values `f` into the forget-gate layer-norm coefficient
// tensor, quantized to int16 through QuantizeAndPopulate. The constructor only
// adds this tensor when use_layer_norm is set.
void SetForgetLayerNormCoefficients(const std::vector<float>& f) {
  const int tensor_index = forget_layer_norm_coefficients_;
  QuantizeAndPopulate<int16_t>(tensor_index, f);
}
|
|
|
|
// Writes the float values `f` into the cell-gate layer-norm coefficient
// tensor, quantized to int16 through QuantizeAndPopulate. The constructor only
// adds this tensor when use_layer_norm is set.
void SetCellLayerNormCoefficients(const std::vector<float>& f) {
  const int tensor_index = cell_layer_norm_coefficients_;
  QuantizeAndPopulate<int16_t>(tensor_index, f);
}
|
|
|
|
// Writes the float values `f` into the output-gate layer-norm coefficient
// tensor, quantized to int16 through QuantizeAndPopulate. The constructor only
// adds this tensor when use_layer_norm is set.
void SetOutputLayerNormCoefficients(const std::vector<float>& f) {
  const int tensor_index = output_layer_norm_coefficients_;
  QuantizeAndPopulate<int16_t>(tensor_index, f);
}
|
|
|
|
// Writes the float values `f` into the input-gate bias tensor, quantized to
// int32 through QuantizeAndPopulate. The tensor is a null input under CIFG.
void SetInputGateBias(const std::vector<float>& f) {
  const int tensor_index = input_gate_bias_;
  QuantizeAndPopulate<int32_t>(tensor_index, f);
}
|
|
|
|
// Writes the float values `f` into the forget-gate bias tensor, quantized to
// int32 through QuantizeAndPopulate.
void SetForgetGateBias(const std::vector<float>& f) {
  const int tensor_index = forget_gate_bias_;
  QuantizeAndPopulate<int32_t>(tensor_index, f);
}
|
|
|
|
// Writes the float values `f` into the cell-gate bias tensor, quantized to
// int32 through QuantizeAndPopulate.
void SetCellBias(const std::vector<float>& f) {
  const int tensor_index = cell_gate_bias_;
  QuantizeAndPopulate<int32_t>(tensor_index, f);
}
|
|
|
|
// Writes the float values `f` into the output-gate bias tensor, quantized to
// int32 through QuantizeAndPopulate.
void SetOutputGateBias(const std::vector<float>& f) {
  const int tensor_index = output_gate_bias_;
  QuantizeAndPopulate<int32_t>(tensor_index, f);
}
|
|
|
|
// Writes the float values `f` into the projection weight tensor (declared as
// INT8 with shape {n_output, n_cell}), quantized to int8 through
// QuantizeAndPopulate. The tensor is a null input unless
// use_projection_weights was set.
void SetProjectionWeights(const std::vector<float>& f) {
  const int tensor_index = projection_weights_;
  QuantizeAndPopulate<int8_t>(tensor_index, f);
}
|
|
|
|
// Writes the float values `f` into the projection bias tensor, quantized to
// int32 through QuantizeAndPopulate. The tensor is a null input unless
// use_projection_bias was set.
void SetProjectionBias(const std::vector<float>& f) {
  const int tensor_index = projection_bias_;
  QuantizeAndPopulate<int32_t>(tensor_index, f);
}
|
|
|
|
// Writes the float activations `f` into the model's input tensor, quantized to
// int8 through QuantizeAndPopulate.
void SetInput(const std::vector<float>& f) {
  const int tensor_index = input_;
  QuantizeAndPopulate<int8_t>(tensor_index, f);
}
|
|
|
|
// Returns the output tensor's contents as raw int8 values; the tests compare
// these directly against quantized expectations.
std::vector<int8_t> GetOutput() {
  std::vector<int8_t> raw_values = ExtractVector<int8_t>(output_);
  return raw_values;
}
|
|
|
|
int num_inputs() { return n_input_; }
|
|
int num_outputs() { return n_output_; }
|
|
|
|
protected:
  // Tensor indices returned by AddInput()/AddNullInput()/AddOutput() in the
  // constructor.
  //
  // Every index defaults to -1, the value TFLite uses for an omitted optional
  // tensor (kTfLiteOptionalTensor). The constructor does not assign all of
  // them on every path: the four *_layer_norm_coefficients_ indices are only
  // set when use_layer_norm is true. Without a default, calling one of their
  // setters on such a model would read an indeterminate int (UB). With the
  // default, it deterministically targets a non-existent tensor instead of,
  // say, tensor 0.
  int input_ = -1;
  int input_to_input_weights_ = -1;
  int input_to_forget_weights_ = -1;
  int input_to_cell_weights_ = -1;
  int input_to_output_weights_ = -1;

  int recurrent_to_input_weights_ = -1;
  int recurrent_to_forget_weights_ = -1;
  int recurrent_to_cell_weights_ = -1;
  int recurrent_to_output_weights_ = -1;

  // Peephole weights; null inputs unless use_peephole is set.
  int cell_to_input_weights_ = -1;
  int cell_to_forget_weights_ = -1;
  int cell_to_output_weights_ = -1;

  // Layer-norm coefficients; only added when use_layer_norm is set.
  int input_layer_norm_coefficients_ = -1;
  int forget_layer_norm_coefficients_ = -1;
  int cell_layer_norm_coefficients_ = -1;
  int output_layer_norm_coefficients_ = -1;

  int input_gate_bias_ = -1;
  int forget_gate_bias_ = -1;
  int cell_gate_bias_ = -1;
  int output_gate_bias_ = -1;

  int projection_weights_ = -1;
  int projection_bias_ = -1;

  int output_ = -1;

  // Dimensions reported by num_inputs()/num_outputs(). A constructor
  // mem-initializer, if present, overrides these defaults.
  int n_input_ = 0;
  int n_output_ = 0;
|
|
};
|
|
|
|
// Integer unidirectional sequence LSTM (int8 input/output, int16 state) with
// layer normalization and a projection layer, but without CIFG and without
// peephole connections. The model is time-major, so both the input and the
// output are laid out as [sequence_length][n_batch][features].
TEST(IntegerUnidirectionalSequenceLstmOpTest,
     NoCifg_NoPeephole_Projection_LayerNorm) {
  // Hyper parameters.
  const int n_batch = 2;
  const int n_input = 5;
  const int n_cell = 4;
  const int n_output = 3;
  const int sequence_length = 3;

  // Model related weights.
  const std::vector<float> input_to_input_weights = {
      0.5,  0.6,  0.7,  -0.8, -0.9, 0.1,  0.2,  0.3,  -0.4, 0.5,
      -0.8, 0.7,  -0.6, 0.5,  -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};

  const std::vector<float> input_to_forget_weights = {
      -0.6, -0.1, 0.3,  0.2,  0.9,  -0.5, -0.2, -0.4, 0.3,  -0.8,
      -0.4, 0.3,  -0.5, -0.4, -0.6, 0.3,  -0.4, -0.6, -0.5, -0.5};

  const std::vector<float> input_to_cell_weights = {
      -0.4, -0.3, -0.2, -0.1, -0.5, 0.5,  -0.2, -0.3, -0.2, -0.6,
      0.6,  -0.1, -0.4, -0.3, -0.7, 0.7,  -0.9, -0.5, 0.8,  0.6};

  const std::vector<float> input_to_output_weights = {
      -0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3,  -0.3, -0.8, -0.2,
      0.6,  -0.2, 0.4,  -0.7, -0.3, -0.5, 0.1,  0.5,  -0.6, -0.4};

  const std::vector<float> input_gate_bias = {0.03, 0.15, 0.22, 0.38};

  const std::vector<float> forget_gate_bias = {0.1, -0.3, -0.2, 0.1};

  const std::vector<float> cell_gate_bias = {-0.05, 0.72, 0.25, 0.08};

  const std::vector<float> output_gate_bias = {0.05, -0.01, 0.2, 0.1};

  const std::vector<float> recurrent_to_input_weights = {
      -0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};

  const std::vector<float> recurrent_to_cell_weights = {
      -0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};

  const std::vector<float> recurrent_to_forget_weights = {
      -0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2};

  const std::vector<float> recurrent_to_output_weights = {
      0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};

  const std::vector<float> input_layer_norm_coefficients = {0.1, 0.2, 0.3, 0.5};
  const std::vector<float> forget_layer_norm_coefficients = {0.2, 0.2, 0.4,
                                                             0.3};
  const std::vector<float> cell_layer_norm_coefficients = {0.7, 0.2, 0.3, 0.8};
  const std::vector<float> output_layer_norm_coefficients = {0.6, 0.2, 0.2,
                                                             0.5};

  const std::vector<float> projection_weights = {
      -0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2};

  // Input ranges.
  const std::vector<std::pair<float, float>> ranges = {
      {-1.0, 127.0 / 128},  // input tensor
      {-1.0, 1.0},          // input_to_input_weight tensor
      {-1.0, 1.0},          // input_to_forget_weight tensor
      {-1.0, 1.0},          // input_to_cell_weight tensor
      {-1.0, 1.0},          // input_to_output_weight tensor

      {-1.0, 1.0},  // recurrent_to_input_weight tensor
      {-1.0, 1.0},  // recurrent_to_forget_weight tensor
      {-1.0, 1.0},  // recurrent_to_cell_weight tensor
      {-1.0, 1.0},  // recurrent_to_output_weight tensor

      {-1, 1},  // cell_to_input_weight tensor
      {-1, 1},  // cell_to_forget_weight tensor
      {-1, 1},  // cell_to_output_weight tensor

      {-100, 100},  // input_gate_bias tensor
      {-100, 100},  // forget_gate_bias tensor
      {-100, 100},  // cell_gate_bias tensor
      {-100, 100},  // output_gate_bias tensor

      {-0.5, 0.5},  // projection_weight tensor
      {-1, 1},      // projection_bias tensor

      {-1.0, 32767.0 / 32768},  // output_state tensor
      {-1, 1},                  // cell_state tensor

      {-1.00001, 1.0},  // input_layer_norm_coefficient tensor
      {-1.00001, 1.0},  // forget_layer_norm_coefficient tensor
      {-1.00001, 1.0},  // cell_layer_norm_coefficient tensor
      {-1.00001, 1.0},  // output_layer_norm_coefficient tensor
      // Output scale is the same as output_state scale and only output_state
      // scale is used in the op, so this is only provided for clarity.
      {-1.0, 32767.0 / 32768},  // output tensor.
  };

  // The scale and zero point of intermediate tensors.
  std::vector<std::pair<float, int>> intermediates = {
      {0.007059, 0}, {0.007812, 0}, {0.007059, 0}, {0.007812, 0}, {0.007, 0}};

  // Create model.
  UnidirectionalSequenceLSTMIntegerOpModel lstm(
      n_batch, n_input, n_cell, n_output, sequence_length, /*time_major=*/true,
      /*use_cifg=*/false, /*use_peephole=*/false,
      /*use_projection_weights=*/true,
      /*use_projection_bias=*/false,
      /*use_layer_norm=*/true,
      /*use_8x8_8_implementation=*/false, ranges, intermediates);

  // Do allocate.
  lstm.PerformAllocateAndDelegate();

  // Set weights.
  lstm.SetInputToInputWeights(input_to_input_weights);
  lstm.SetInputToCellWeights(input_to_cell_weights);
  lstm.SetInputToForgetWeights(input_to_forget_weights);
  lstm.SetInputToOutputWeights(input_to_output_weights);

  lstm.SetInputGateBias(input_gate_bias);
  lstm.SetCellBias(cell_gate_bias);
  lstm.SetForgetGateBias(forget_gate_bias);
  lstm.SetOutputGateBias(output_gate_bias);

  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights);
  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights);
  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights);

  lstm.SetProjectionWeights(projection_weights);

  lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients);
  lstm.SetForgetLayerNormCoefficients(forget_layer_norm_coefficients);
  lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients);
  lstm.SetOutputLayerNormCoefficients(output_layer_norm_coefficients);

  // Model inputs, time-major: [sequence_length][n_batch][n_input].
  const std::vector<float> lstm_input = {
      0.7, 0.8, 0.1, 0.2, 0.3,  //
      0.8, 0.1, 0.2, 0.4, 0.5,  //
      0.2, 0.7, 0.7, 0.1, 0.7,  //
      0.3, 0.2, 0.9, 0.8, 0.1,  //
      0.7, 0.8, 0.1, 0.2, 0.3,  //
      0.3, 0.2, 0.9, 0.8, 0.1,  //
  };
  // Guard the fixture: a mis-sized input would be written past the end of the
  // input tensor rather than reported as a test failure.
  ASSERT_EQ(lstm_input.size(),
            static_cast<size_t>(sequence_length * n_batch * n_input));

  // Expected outputs, time-major: [sequence_length][n_batch][n_output].
  const std::vector<int8_t> expected_output = {
      127,  127, -108, -67, 127, 127, -128, 127, 127,
      -128, 127, 127,  127, 127, 127, -128, 127, 127,
  };

  // Invoke and verify the result.
  lstm.SetInput(lstm_input);
  lstm.Invoke();
  EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(expected_output));
}
|
|
|
|
// Integer unidirectional sequence LSTM (int8 input/output, int16 state) with
// peephole connections, layer normalization and a projection layer, but
// without CIFG. The model is time-major, so both the input and the output are
// laid out as [sequence_length][n_batch][features].
TEST(IntegerUnidirectionalSequenceLstmOpTest,
     NoCifg_Peephole_Projection_LayerNorm) {
  // Hyper parameters.
  const int n_batch = 2;
  const int n_input = 5;
  const int n_cell = 4;
  const int n_output = 3;
  const int sequence_length = 3;

  // Model related weights.
  const std::vector<float> input_to_input_weights = {
      0.5,  0.6,  0.7,  -0.8, -0.9, 0.1,  0.2,  0.3,  -0.4, 0.5,
      -0.8, 0.7,  -0.6, 0.5,  -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};

  const std::vector<float> input_to_forget_weights = {
      -0.6, -0.1, 0.3,  0.2,  0.9,  -0.5, -0.2, -0.4, 0.3,  -0.8,
      -0.4, 0.3,  -0.5, -0.4, -0.6, 0.3,  -0.4, -0.6, -0.5, -0.5};

  const std::vector<float> input_to_cell_weights = {
      -0.4, -0.3, -0.2, -0.1, -0.5, 0.5,  -0.2, -0.3, -0.2, -0.6,
      0.6,  -0.1, -0.4, -0.3, -0.7, 0.7,  -0.9, -0.5, 0.8,  0.6};

  const std::vector<float> input_to_output_weights = {
      -0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3,  -0.3, -0.8, -0.2,
      0.6,  -0.2, 0.4,  -0.7, -0.3, -0.5, 0.1,  0.5,  -0.6, -0.4};

  const std::vector<float> input_gate_bias = {0.03, 0.15, 0.22, 0.38};

  const std::vector<float> forget_gate_bias = {0.1, -0.3, -0.2, 0.1};

  const std::vector<float> cell_gate_bias = {-0.05, 0.72, 0.25, 0.08};

  const std::vector<float> output_gate_bias = {0.05, -0.01, 0.2, 0.1};

  const std::vector<float> recurrent_to_input_weights = {
      -0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};

  const std::vector<float> recurrent_to_cell_weights = {
      -0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};

  const std::vector<float> recurrent_to_forget_weights = {
      -0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2};

  const std::vector<float> recurrent_to_output_weights = {
      0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};

  const std::vector<float> cell_to_input_weights = {0.3, -0.1, 0.1, -0.2};

  const std::vector<float> cell_to_forget_weights = {0.2, -0.1, 0.1, -0.2};

  const std::vector<float> cell_to_output_weights = {0.3, -0.1, 0.1, -0.3};

  const std::vector<float> input_layer_norm_coefficients = {0.1, 0.2, 0.3, 0.5};
  const std::vector<float> forget_layer_norm_coefficients = {0.2, 0.2, 0.4,
                                                             0.3};
  const std::vector<float> cell_layer_norm_coefficients = {0.7, 0.2, 0.3, 0.8};
  const std::vector<float> output_layer_norm_coefficients = {0.6, 0.2, 0.2,
                                                             0.5};

  const std::vector<float> projection_weights = {
      -0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2};

  // Input ranges.
  const std::vector<std::pair<float, float>> ranges = {
      {-1.0, 127.0 / 128},  // input tensor
      {-1.0, 1.0},          // input_to_input_weight tensor
      {-1.0, 1.0},          // input_to_forget_weight tensor
      {-1.0, 1.0},          // input_to_cell_weight tensor
      {-1.0, 1.0},          // input_to_output_weight tensor

      {-1.0, 1.0},  // recurrent_to_input_weight tensor
      {-0.9, 0.9},  // recurrent_to_forget_weight tensor
      {-1.0, 1.0},  // recurrent_to_cell_weight tensor
      {-1.0, 1.0},  // recurrent_to_output_weight tensor

      {-0.3, 0.3},  // cell_to_input_weight tensor
      {-0.3, 0.3},  // cell_to_forget_weight tensor
      {-0.3, 0.3},  // cell_to_output_weight tensor

      {-100, 100},  // input_gate_bias tensor
      {-100, 80},   // forget_gate_bias tensor
      {-100, 100},  // cell_gate_bias tensor
      {-100, 100},  // output_gate_bias tensor

      {-0.5, 0.5},  // projection_weight tensor
      {-1, 1},      // projection_bias tensor

      {-1.0, 32767.0 / 32768},  // output_state tensor
      {-1, 1},                  // cell_state tensor

      {-0.5, 0.5},  // input_layer_norm_coefficient tensor
      {-0.5, 0.5},  // forget_layer_norm_coefficient tensor
      {-1.0, 1.0},  // cell_layer_norm_coefficient tensor
      {-1.0, 1.0},  // output_layer_norm_coefficient tensor
      // Output scale is the same as output_state scale and only output_state
      // scale is used in the op, so this is only provided for clarity.
      {-1.0, 32767.0 / 32768},  // output tensor.
  };

  // The scale and zero point of intermediate tensors.
  std::vector<std::pair<float, int>> intermediates = {
      {0.007059, 0}, {0.007812, 0}, {0.007059, 0}, {0.007812, 0}, {0.007, 0}};

  // Create model.
  UnidirectionalSequenceLSTMIntegerOpModel lstm(
      n_batch, n_input, n_cell, n_output, sequence_length, /*time_major=*/true,
      /*use_cifg=*/false, /*use_peephole=*/true,
      /*use_projection_weights=*/true,
      /*use_projection_bias=*/false,
      /*use_layer_norm=*/true,
      /*use_8x8_8_implementation=*/false, ranges, intermediates);

  // Do allocate.
  lstm.PerformAllocateAndDelegate();

  // Set weights.
  lstm.SetInputToInputWeights(input_to_input_weights);
  lstm.SetInputToCellWeights(input_to_cell_weights);
  lstm.SetInputToForgetWeights(input_to_forget_weights);
  lstm.SetInputToOutputWeights(input_to_output_weights);

  lstm.SetInputGateBias(input_gate_bias);
  lstm.SetCellBias(cell_gate_bias);
  lstm.SetForgetGateBias(forget_gate_bias);
  lstm.SetOutputGateBias(output_gate_bias);

  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights);
  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights);
  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights);

  lstm.SetCellToInputWeights(cell_to_input_weights);
  lstm.SetCellToForgetWeights(cell_to_forget_weights);
  lstm.SetCellToOutputWeights(cell_to_output_weights);

  lstm.SetProjectionWeights(projection_weights);

  lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients);
  lstm.SetForgetLayerNormCoefficients(forget_layer_norm_coefficients);
  lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients);
  lstm.SetOutputLayerNormCoefficients(output_layer_norm_coefficients);

  // Model inputs, time-major: [sequence_length][n_batch][n_input].
  const std::vector<float> lstm_input = {
      0.7, 0.8, 0.1, 0.2, 0.3,  //
      0.8, 0.1, 0.2, 0.4, 0.5,  //
      0.2, 0.7, 0.7, 0.1, 0.7,  //
      0.3, 0.2, 0.9, 0.8, 0.1,  //
      0.7, 0.8, 0.1, 0.2, 0.3,  //
      0.3, 0.2, 0.9, 0.8, 0.1,  //
  };
  // Guard the fixture: a mis-sized input would be written past the end of the
  // input tensor rather than reported as a test failure.
  ASSERT_EQ(lstm_input.size(),
            static_cast<size_t>(sequence_length * n_batch * n_input));

  // Expected outputs, time-major: [sequence_length][n_batch][n_output].
  const std::vector<int8_t> expected_output = {
      127,  127, -16, -21, 127, 127, 23,   127, 127,
      -128, 127, 127, 127, 127, 127, -128, 127, 127,
  };

  // Invoke and verify the result.
  lstm.SetInput(lstm_input);
  lstm.Invoke();
  EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(expected_output));
}
|
|
|
|
#define QUANTIZE_PARAMETER_TEST(test) \
|
|
INSTANTIATE_TEST_SUITE_P(test, test, ::testing::ValuesIn({false, true}));
|
|
|
|
QUANTIZE_PARAMETER_TEST(
|
|
CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest);
|
|
QUANTIZE_PARAMETER_TEST(
|
|
NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest);
|
|
QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest);
|
|
#undef QUANTIZE_PARAMETER_TEST
|
|
} // namespace
|
|
} // namespace tflite
|