Add layer norm to unidirectional lstm. With this change, unidirectional lstm takes 24 inputs with or without layer norm. The 20 input case is only kept for backward compatibility.

PiperOrigin-RevId: 248797274
This commit is contained in:
Jian Li 2019-05-17 15:19:40 -07:00 committed by TensorFlower Gardener
parent c6e612f099
commit 731cd5e3ce
2 changed files with 441 additions and 25 deletions

View File

@ -34,6 +34,13 @@ namespace ops {
namespace builtin {
namespace unidirectional_sequence_lstm {
struct OpData {
// If the lstm is layer norm.
bool is_layer_norm_lstm;
// The scratch tensor index.
int scratch_tensor_index;
};
// Input Tensors of size {max_time, n_batch, n_input}
constexpr int kInputTensor = 0;
@ -71,6 +78,13 @@ constexpr int kInputActivationStateTensor = 18;
// Cell state tensor of size {n_batch, n_cell}
constexpr int kInputCellStateTensor = 19;
// Layer norm coefficient tensors of size {n_cell}, representing a diagonal
// matrix.
constexpr int kInputLayerNormCoefficientsTensor = 20; // Optional
constexpr int kForgetLayerNormCoefficientsTensor = 21; // Optional
constexpr int kCellLayerNormCoefficientsTensor = 22; // Optional
constexpr int kOutputLayerNormCoefficientsTensor = 23; // Optional
// Output tensors.
constexpr int kOutputTensor = 0;
@ -87,19 +101,21 @@ enum TemporaryTensor {
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* scratch_tensor_index = new int();
context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
return scratch_tensor_index;
auto* op_data = new OpData();
context->AddTensors(context, kNumTemporaryTensors,
&op_data->scratch_tensor_index);
return op_data;
}
void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<int*>(buffer);
delete reinterpret_cast<OpData*>(buffer);
}
// Check that input tensor dimensions matches with each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TfLiteNode* node, int n_input,
int n_output, int n_cell) {
int n_output, int n_cell,
bool is_layer_norm_lstm) {
const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
// Making sure clipping parameters have valid values.
@ -242,6 +258,48 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
((projection_weights != nullptr) || (projection_bias == nullptr));
TF_LITE_ENSURE(context, projecton_tensors_consistent == true);
if (is_layer_norm_lstm) {
const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
context, node, kInputLayerNormCoefficientsTensor);
if (use_cifg) {
TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr);
} else {
TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr);
TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0],
n_cell);
TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->type,
kTfLiteFloat32);
}
const TfLiteTensor* forget_layer_norm_coefficients =
GetInput(context, node, kForgetLayerNormCoefficientsTensor);
TF_LITE_ENSURE(context, forget_layer_norm_coefficients != nullptr);
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
n_cell);
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->type,
kTfLiteFloat32);
const TfLiteTensor* cell_layer_norm_coefficients =
GetInput(context, node, kCellLayerNormCoefficientsTensor);
TF_LITE_ENSURE(context, cell_layer_norm_coefficients != nullptr);
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
n_cell);
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->type,
kTfLiteFloat32);
const TfLiteTensor* output_layer_norm_coefficients =
GetInput(context, node, kOutputLayerNormCoefficientsTensor);
TF_LITE_ENSURE(context, output_layer_norm_coefficients != nullptr);
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
n_cell);
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->type,
kTfLiteFloat32);
}
return kTfLiteOk;
}
@ -249,11 +307,30 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
// Allocate a temporary scratch tensor. Also check that the sizes of the input
// tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
const int scratch_tensor_index = op_data->scratch_tensor_index;
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 20);
bool is_layer_norm_lstm = false;
if (node->inputs->size == 24) {
const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
context, node, kForgetLayerNormCoefficientsTensor);
if (forget_layer_norm_coefficients == nullptr) {
is_layer_norm_lstm = false;
} else {
is_layer_norm_lstm = true;
}
} else if (node->inputs->size == 20) {
// This is deprecated and is only kept here for backward compatibility.
is_layer_norm_lstm = false;
} else {
context->ReportError(
context, "The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs",
node->inputs->size);
return kTfLiteError;
}
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
op_data->is_layer_norm_lstm = is_layer_norm_lstm;
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
@ -281,8 +358,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int n_output = recurrent_to_output_weights->dims->data[1];
// Check that input tensor dimensions matches with each other.
TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
n_output, n_cell));
TF_LITE_ENSURE_OK(context,
CheckInputTensorDimensions(context, node, n_input, n_output,
n_cell, is_layer_norm_lstm));
// Get the pointer to output, activation_state and cell_state buffer tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
@ -310,7 +388,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} else {
node->temporaries = TfLiteIntArrayCreate(1);
}
node->temporaries->data[0] = *scratch_tensor_index;
node->temporaries->data[0] = scratch_tensor_index;
// Create a scratch buffer tensor.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer);
@ -336,7 +414,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Allocate temporary tensors to store quantized values of input,
// activation_state and cell_state tensors.
node->temporaries->data[kInputQuantized] =
*scratch_tensor_index + kInputQuantized;
scratch_tensor_index + kInputQuantized;
TfLiteTensor* input_quantized =
GetTemporary(context, node, kInputQuantized);
input_quantized->type = input_to_output_weights->type;
@ -347,7 +425,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
input_quantized_size));
}
node->temporaries->data[kOutputStateQuantized] =
*scratch_tensor_index + kOutputStateQuantized;
scratch_tensor_index + kOutputStateQuantized;
TfLiteTensor* activation_state_quantized =
GetTemporary(context, node, kOutputStateQuantized);
activation_state_quantized->type = input_to_output_weights->type;
@ -361,7 +439,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
activation_state_quantized_size));
}
node->temporaries->data[kCellStateQuantized] =
*scratch_tensor_index + kCellStateQuantized;
scratch_tensor_index + kCellStateQuantized;
TfLiteTensor* cell_state_quantized =
GetTemporary(context, node, kCellStateQuantized);
cell_state_quantized->type = input_to_output_weights->type;
@ -380,7 +458,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// different matrices (which requires multiplying the scaling factors with
// the scaling factor of the matrix).
node->temporaries->data[kScalingFactors] =
*scratch_tensor_index + kScalingFactors;
scratch_tensor_index + kScalingFactors;
TfLiteTensor* scaling_factors =
GetTemporary(context, node, kScalingFactors);
scaling_factors->type = kTfLiteFloat32;
@ -393,7 +471,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
scaling_factors_size));
}
node->temporaries->data[kProductScalingFactors] =
*scratch_tensor_index + kProductScalingFactors;
scratch_tensor_index + kProductScalingFactors;
TfLiteTensor* prod_scaling_factors =
GetTemporary(context, node, kProductScalingFactors);
prod_scaling_factors->type = kTfLiteFloat32;
@ -410,7 +488,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Allocate a temporary tensor to store the recovered cell weights. Since
// this is used for diagonal matrices, only need to store n_cell values.
node->temporaries->data[kRecoveredCellWeights] =
*scratch_tensor_index + kRecoveredCellWeights;
scratch_tensor_index + kRecoveredCellWeights;
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, kRecoveredCellWeights);
recovered_cell_weights->type = kTfLiteFloat32;
@ -432,6 +510,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params =
reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
node->builtin_data);
const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
const bool is_layer_norm_lstm = op_data->is_layer_norm_lstm;
const bool time_major = params->time_major;
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
@ -481,6 +561,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* cell_state =
GetVariableInput(context, node, kInputCellStateTensor);
const TfLiteTensor* input_layer_norm_coefficients =
is_layer_norm_lstm ? GetOptionalInputTensor(
context, node, kInputLayerNormCoefficientsTensor)
: nullptr;
const TfLiteTensor* forget_layer_norm_coefficients =
is_layer_norm_lstm
? GetInput(context, node, kForgetLayerNormCoefficientsTensor)
: nullptr;
const TfLiteTensor* cell_layer_norm_coefficients =
is_layer_norm_lstm
? GetInput(context, node, kCellLayerNormCoefficientsTensor)
: nullptr;
const TfLiteTensor* output_layer_norm_coefficients =
is_layer_norm_lstm
? GetInput(context, node, kOutputLayerNormCoefficientsTensor)
: nullptr;
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// Copy out the LSTM specific params so they can be passed in the function.
@ -497,10 +594,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
recurrent_to_input_weights, recurrent_to_forget_weights,
recurrent_to_cell_weights, recurrent_to_output_weights,
cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
/*input_layer_norm_coefficients=*/nullptr,
/*forget_layer_norm_coefficients=*/nullptr,
/*cell_layer_norm_coefficients=*/nullptr,
/*output_layer_norm_coefficients=*/nullptr,
input_layer_norm_coefficients, forget_layer_norm_coefficients,
cell_layer_norm_coefficients, output_layer_norm_coefficients,
/*aux_input=*/nullptr,
/*aux_input_to_input_weights=*/nullptr,
/*aux_input_to_forget_weights=*/nullptr,
@ -529,10 +624,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
recurrent_to_input_weights, recurrent_to_forget_weights,
recurrent_to_cell_weights, recurrent_to_output_weights,
cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
/*input_layer_norm_coefficients=*/nullptr,
/*forget_layer_norm_coefficients=*/nullptr,
/*cell_layer_norm_coefficients=*/nullptr,
/*output_layer_norm_coefficients=*/nullptr,
input_layer_norm_coefficients, forget_layer_norm_coefficients,
cell_layer_norm_coefficients, output_layer_norm_coefficients,
/*aux_input=*/nullptr,
/*aux_input_to_input_weights=*/nullptr,
/*aux_input_to_forget_weights=*/nullptr,

View File

@ -37,7 +37,8 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
bool use_projection_bias, float cell_clip,
float proj_clip,
const std::vector<std::vector<int>>& input_shapes,
const TensorType& weights_type = TensorType_FLOAT32)
const TensorType& weights_type = TensorType_FLOAT32,
bool is_layer_norm = false)
: n_batch_(n_batch),
n_input_(n_input),
n_cell_(n_cell),
@ -108,6 +109,22 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}},
/*is_variable=*/true);
// Layer norm weights.
if (is_layer_norm) {
if (use_cifg) {
input_layer_norm_coefficients_ = AddNullInput();
} else {
input_layer_norm_coefficients_ =
AddLayerNormCoeffsTensor(20, input_shapes);
}
forget_layer_norm_coefficients_ =
AddLayerNormCoeffsTensor(21, input_shapes);
cell_layer_norm_coefficients_ =
AddLayerNormCoeffsTensor(22, input_shapes);
output_layer_norm_coefficients_ =
AddLayerNormCoeffsTensor(23, input_shapes);
}
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
@ -187,6 +204,22 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
PopulateTensor(projection_bias_, f);
}
void SetInputLayerNormCoefficients(std::vector<float> f) {
PopulateTensor(input_layer_norm_coefficients_, f);
}
void SetForgetLayerNormCoefficients(std::vector<float> f) {
PopulateTensor(forget_layer_norm_coefficients_, f);
}
void SetCellLayerNormCoefficients(std::vector<float> f) {
PopulateTensor(cell_layer_norm_coefficients_, f);
}
void SetOutputLayerNormCoefficients(std::vector<float> f) {
PopulateTensor(output_layer_norm_coefficients_, f);
}
void SetInput(int offset, const float* begin, const float* end) {
PopulateTensor(input_, offset, const_cast<float*>(begin),
const_cast<float*>(end));
@ -227,6 +260,11 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
int input_activation_state_;
int input_cell_state_;
int input_layer_norm_coefficients_;
int forget_layer_norm_coefficients_;
int cell_layer_norm_coefficients_;
int output_layer_norm_coefficients_;
int output_;
int n_batch_;
@ -234,6 +272,16 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
int n_cell_;
int n_output_;
int sequence_length_;
private:
int AddLayerNormCoeffsTensor(
int tensor_index, const std::vector<std::vector<int>>& input_shapes) {
if (input_shapes[tensor_index][0] != 0) {
return AddInput(TensorType_FLOAT32);
} else {
return AddNullInput();
}
}
};
// The hybrid model has quantized weights.
@ -2403,6 +2451,281 @@ TEST_F(NoCifgPeepholeProjectionAndBiasClippingLstmTest, LstmBlackBoxTest) {
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
class LayerNormUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
public:
LayerNormUnidirectionalLSTMOpModel(
int n_batch, int n_input, int n_cell, int n_output, int sequence_length,
bool time_major, bool use_cifg, bool use_peephole,
bool use_projection_weights, bool use_projection_bias, float cell_clip,
float proj_clip, const std::vector<std::vector<int>>& input_shapes,
const TensorType& weights_type = TensorType_FLOAT32)
: UnidirectionalLSTMOpModel(
n_batch, n_input, n_cell, n_output, sequence_length, time_major,
use_cifg, use_peephole, use_projection_weights, use_projection_bias,
cell_clip, proj_clip, input_shapes, TensorType_FLOAT32, true) {}
};
class BaseLayerNormLstmTest : public ::testing::Test {
protected:
// Weights of the LSTM model. Some are optional.
std::vector<float> input_to_input_weights_;
std::vector<float> input_to_cell_weights_;
std::vector<float> input_to_forget_weights_;
std::vector<float> input_to_output_weights_;
std::vector<float> input_gate_bias_;
std::vector<float> cell_gate_bias_;
std::vector<float> forget_gate_bias_;
std::vector<float> output_gate_bias_;
std::vector<float> recurrent_to_input_weights_;
std::vector<float> recurrent_to_cell_weights_;
std::vector<float> recurrent_to_forget_weights_;
std::vector<float> recurrent_to_output_weights_;
std::vector<float> cell_to_input_weights_;
std::vector<float> cell_to_forget_weights_;
std::vector<float> cell_to_output_weights_;
std::vector<float> projection_weights_;
std::vector<float> projection_bias_;
std::vector<float> input_layer_norm_coefficients_;
std::vector<float> forget_layer_norm_coefficients_;
std::vector<float> cell_layer_norm_coefficients_;
std::vector<float> output_layer_norm_coefficients_;
// LSTM input is stored as num_batch x num_inputs vector.
std::vector<std::vector<float>> lstm_input_;
// LSTM output is stored as num_batch x num_outputs vector.
std::vector<std::vector<float>> lstm_golden_output_;
// Compares output up to tolerance to the result of the lstm given the input.
void VerifyGoldens(const std::vector<std::vector<float>>& input,
const std::vector<std::vector<float>>& output,
UnidirectionalLSTMOpModel* lstm, float tolerance = 1e-5) {
const int num_batches = input.size();
EXPECT_GT(num_batches, 0);
const int num_inputs = lstm->num_inputs();
EXPECT_GT(num_inputs, 0);
const int input_sequence_size = input[0].size() / num_inputs;
EXPECT_GT(input_sequence_size, 0);
// Feed the whole sequence as input.
for (int i = 0; i < input_sequence_size; ++i) {
for (int b = 0; b < num_batches; ++b) {
const float* batch_start = input[b].data() + i * num_inputs;
const float* batch_end = batch_start + num_inputs;
lstm->SetInput(((i * num_batches) + b) * num_inputs, batch_start,
batch_end);
}
}
lstm->Invoke();
const int num_outputs = lstm->num_outputs();
EXPECT_GT(num_outputs, 0);
std::vector<float> expected;
for (int i = 0; i < input_sequence_size; ++i) {
for (int b = 0; b < num_batches; ++b) {
const float* golden_start_batch = output[b].data() + i * num_outputs;
const float* golden_end_batch = golden_start_batch + num_outputs;
expected.insert(expected.end(), golden_start_batch, golden_end_batch);
}
}
EXPECT_THAT(lstm->GetOutput(),
ElementsAreArray(ArrayFloatNear(expected, tolerance)));
}
};
class CifgPeepholeNoProjectionNoClippingLayerNormLstmTest
: public BaseLayerNormLstmTest {
void SetUp() override {
input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726,
0.05100781, 0.04717243, 0.48944736,
-0.38535351, -0.17212132};
input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988,
-0.3633365, -0.22755712, 0.28253698,
0.24407166, 0.33826375};
input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593,
-0.09426838, -0.44257352, 0.54939759,
0.01533556, 0.42751634};
cell_gate_bias_ = {0., 0., 0., 0.};
forget_gate_bias_ = {1., 1., 1., 1.};
output_gate_bias_ = {0., 0., 0., 0.};
recurrent_to_cell_weights_ = {
0.54066205, -0.32668582, -0.43562764, -0.56094903,
0.42957711, 0.01841056, -0.32764608, -0.33027974,
-0.10826075, 0.20675004, 0.19069612, -0.03026325,
-0.54532051, 0.33003211, 0.44901288, 0.21193194};
recurrent_to_forget_weights_ = {
-0.13832897, -0.0515101, -0.2359007, -0.16661474,
-0.14340827, 0.36986142, 0.23414481, 0.55899,
0.10798943, -0.41174671, 0.17751795, -0.34484994,
-0.35874045, -0.11352962, 0.27268326, 0.54058349};
recurrent_to_output_weights_ = {
0.41613156, 0.42610586, -0.16495961, -0.5663873,
0.30579174, -0.05115908, -0.33941799, 0.23364776,
0.11178309, 0.09481031, -0.26424935, 0.46261835,
0.50248802, 0.26114327, -0.43736315, 0.33149987};
cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408,
0.31544167};
cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703,
-0.77109635};
input_layer_norm_coefficients_ = {0.1, 0.2, 0.3, 0.5};
forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3};
cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8};
output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5};
lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
lstm_golden_output_ = {{-0.102089, 0.00653987, 0.0515139, -0.0630045,
-0.173317, 0.0109206, 0.0903292, -0.109497,
-0.23827, 0.0119514, 0.119525, -0.12748}};
}
};
TEST_F(CifgPeepholeNoProjectionNoClippingLayerNormLstmTest,
LayerNormLstmBlackBoxTest) {
const int n_batch = 1;
const int n_input = 2;
// n_cell and n_output have the same size when there is no projection.
const int n_cell = 4;
const int n_output = 4;
const int sequence_length = 3;
LayerNormUnidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length,
/*time_major=*/true, /*use_cifg=*/true, /*use_peephole=*/true,
/*use_projection_weights=*/false,
/*use_projection_bias=*/false,
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
{0, 0}, // input_to_input_weight tensor
{n_cell, n_input}, // input_to_forget_weight tensor
{n_cell, n_input}, // input_to_cell_weight tensor
{n_cell, n_input}, // input_to_output_weight tensor
{0, 0}, // recurrent_to_input_weight tensor
{n_cell, n_output}, // recurrent_to_forget_weight tensor
{n_cell, n_output}, // recurrent_to_cell_weight tensor
{n_cell, n_output}, // recurrent_to_output_weight tensor
{0}, // cell_to_input_weight tensor
{n_cell}, // cell_to_forget_weight tensor
{n_cell}, // cell_to_output_weight tensor
{0}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_cell}, // cell_state tensor
{0}, // input_layer_norm_coefficient tensor
{n_cell}, // forget_layer_norm_coefficient tensor
{n_cell}, // cell_layer_norm_coefficient tensor
{n_cell}, // output_layer_norm_coefficient tensor
});
lstm.SetInputToCellWeights(input_to_cell_weights_);
lstm.SetInputToForgetWeights(input_to_forget_weights_);
lstm.SetInputToOutputWeights(input_to_output_weights_);
lstm.SetCellBias(cell_gate_bias_);
lstm.SetForgetGateBias(forget_gate_bias_);
lstm.SetOutputGateBias(output_gate_bias_);
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
lstm.SetCellToOutputWeights(cell_to_output_weights_);
lstm.SetForgetLayerNormCoefficients(forget_layer_norm_coefficients_);
lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
lstm.SetOutputLayerNormCoefficients(output_layer_norm_coefficients_);
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest,
NonLayerNormLstmBlackBoxTest) {
const int n_batch = 1;
const int n_input = 2;
// n_cell and n_output have the same size when there is no projection.
const int n_cell = 4;
const int n_output = 4;
const int sequence_length = 3;
LayerNormUnidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length,
/*time_major=*/true, /*use_cifg=*/true, /*use_peephole=*/true,
/*use_projection_weights=*/false,
/*use_projection_bias=*/false,
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
{0, 0}, // input_to_input_weight tensor
{n_cell, n_input}, // input_to_forget_weight tensor
{n_cell, n_input}, // input_to_cell_weight tensor
{n_cell, n_input}, // input_to_output_weight tensor
{0, 0}, // recurrent_to_input_weight tensor
{n_cell, n_output}, // recurrent_to_forget_weight tensor
{n_cell, n_output}, // recurrent_to_cell_weight tensor
{n_cell, n_output}, // recurrent_to_output_weight tensor
{0}, // cell_to_input_weight tensor
{n_cell}, // cell_to_forget_weight tensor
{n_cell}, // cell_to_output_weight tensor
{0}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_cell}, // cell_state tensor
{0}, // input_layer_norm_coefficient tensor
{0}, // forget_layer_norm_coefficient tensor
{0}, // cell_layer_norm_coefficient tensor
{0}, // output_layer_norm_coefficient tensor
});
lstm.SetInputToCellWeights(input_to_cell_weights_);
lstm.SetInputToForgetWeights(input_to_forget_weights_);
lstm.SetInputToOutputWeights(input_to_output_weights_);
lstm.SetCellBias(cell_gate_bias_);
lstm.SetForgetGateBias(forget_gate_bias_);
lstm.SetOutputGateBias(output_gate_bias_);
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
lstm.SetCellToOutputWeights(cell_to_output_weights_);
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
} // namespace
} // namespace tflite