Add layer norm to unidirectional LSTM. With this change, the unidirectional LSTM kernel takes 24 inputs with or without layer norm; the 20-input signature is kept only for backward compatibility.
PiperOrigin-RevId: 248797274
parent c6e612f099
commit 731cd5e3ce
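For orientation before the diff: the four new coefficient tensors feed a per-gate layer normalization step, where each gate's n_cell pre-activations are normalized to zero mean and unit variance and then rescaled elementwise by the {n_cell} coefficient vector before the gate bias and activation apply. Below is a minimal sketch of that step, assuming the standard layer-norm formulation; the epsilon value and the placement of the bias add are illustrative, not taken from this change.

#include <cmath>

// Layer-normalizes one LSTM gate's pre-activations for a single batch entry.
// `gate` holds n_cell pre-activation values; `coeff` is one of the new
// {n_cell} layer norm coefficient tensors (the "diagonal matrix" described
// in the kernel comments below); `bias` is the matching gate bias.
void LayerNormGateSketch(float* gate, const float* coeff, const float* bias,
                         int n_cell) {
  float mean = 0.0f;
  for (int i = 0; i < n_cell; ++i) mean += gate[i];
  mean /= n_cell;
  float variance = 0.0f;
  for (int i = 0; i < n_cell; ++i) {
    const float d = gate[i] - mean;
    variance += d * d;
  }
  variance /= n_cell;
  const float inv_stddev = 1.0f / std::sqrt(variance + 1e-8f);  // eps assumed
  for (int i = 0; i < n_cell; ++i) {
    gate[i] = (gate[i] - mean) * inv_stddev * coeff[i] + bias[i];
  }
}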
@@ -34,6 +34,13 @@ namespace ops {
 namespace builtin {
 namespace unidirectional_sequence_lstm {
 
+struct OpData {
+  // Whether the LSTM uses layer normalization.
+  bool is_layer_norm_lstm;
+  // The scratch tensor index.
+  int scratch_tensor_index;
+};
+
 // Input Tensors of size {max_time, n_batch, n_input}
 constexpr int kInputTensor = 0;
 
@@ -71,6 +78,13 @@ constexpr int kInputActivationStateTensor = 18;
 // Cell state tensor of size {n_batch, n_cell}
 constexpr int kInputCellStateTensor = 19;
 
+// Layer norm coefficient tensors of size {n_cell}, representing a diagonal
+// matrix.
+constexpr int kInputLayerNormCoefficientsTensor = 20;   // Optional
+constexpr int kForgetLayerNormCoefficientsTensor = 21;  // Optional
+constexpr int kCellLayerNormCoefficientsTensor = 22;    // Optional
+constexpr int kOutputLayerNormCoefficientsTensor = 23;  // Optional
+
 // Output tensors.
 constexpr int kOutputTensor = 0;
 
@@ -87,19 +101,21 @@ enum TemporaryTensor {
 };
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  auto* scratch_tensor_index = new int();
-  context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
-  return scratch_tensor_index;
+  auto* op_data = new OpData();
+  context->AddTensors(context, kNumTemporaryTensors,
+                      &op_data->scratch_tensor_index);
+  return op_data;
 }
 
 void Free(TfLiteContext* context, void* buffer) {
-  delete reinterpret_cast<int*>(buffer);
+  delete reinterpret_cast<OpData*>(buffer);
 }
 
 // Check that the input tensor dimensions match each other.
 TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                         TfLiteNode* node, int n_input,
-                                        int n_output, int n_cell) {
+                                        int n_output, int n_cell,
+                                        bool is_layer_norm_lstm) {
   const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
 
   // Making sure clipping parameters have valid values.
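The Init/Free change above swaps the bare int stored in node->user_data for the new OpData struct; Prepare and Eval read it back out below. For context, this follows the standard TfLite kernel lifecycle, whose registration glue (not part of this diff, sketched here from the usual layout of these kernel files) wires the four functions together:

// Sketch of the typical registration for this kernel; assumed, not shown in
// the diff. The runtime calls Init once per node, then Prepare/Eval, and
// finally Free, passing the returned OpData back via node->user_data.
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
  static TfLiteRegistration r = {unidirectional_sequence_lstm::Init,
                                 unidirectional_sequence_lstm::Free,
                                 unidirectional_sequence_lstm::Prepare,
                                 unidirectional_sequence_lstm::Eval};
  return &r;
}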
@@ -242,6 +258,48 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
       ((projection_weights != nullptr) || (projection_bias == nullptr));
   TF_LITE_ENSURE(context, projecton_tensors_consistent == true);
 
+  if (is_layer_norm_lstm) {
+    const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
+        context, node, kInputLayerNormCoefficientsTensor);
+    if (use_cifg) {
+      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr);
+    } else {
+      TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr);
+      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1);
+      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0],
+                        n_cell);
+      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->type,
+                        kTfLiteFloat32);
+    }
+
+    const TfLiteTensor* forget_layer_norm_coefficients =
+        GetInput(context, node, kForgetLayerNormCoefficientsTensor);
+    TF_LITE_ENSURE(context, forget_layer_norm_coefficients != nullptr);
+    TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
+    TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
+                      n_cell);
+    TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->type,
+                      kTfLiteFloat32);
+
+    const TfLiteTensor* cell_layer_norm_coefficients =
+        GetInput(context, node, kCellLayerNormCoefficientsTensor);
+    TF_LITE_ENSURE(context, cell_layer_norm_coefficients != nullptr);
+    TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
+    TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
+                      n_cell);
+    TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->type,
+                      kTfLiteFloat32);
+
+    const TfLiteTensor* output_layer_norm_coefficients =
+        GetInput(context, node, kOutputLayerNormCoefficientsTensor);
+    TF_LITE_ENSURE(context, output_layer_norm_coefficients != nullptr);
+    TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
+    TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
+                      n_cell);
+    TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->type,
+                      kTfLiteFloat32);
+  }
+
   return kTfLiteOk;
 }
 
@@ -249,11 +307,30 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
 // Allocate a temporary scratch tensor. Also check that the sizes of the input
 // tensors match each other.
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+  const int scratch_tensor_index = op_data->scratch_tensor_index;
 
   // Check we have all the inputs and outputs we need.
-  TF_LITE_ENSURE_EQ(context, node->inputs->size, 20);
+  bool is_layer_norm_lstm = false;
+  if (node->inputs->size == 24) {
+    const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
+        context, node, kForgetLayerNormCoefficientsTensor);
+    if (forget_layer_norm_coefficients == nullptr) {
+      is_layer_norm_lstm = false;
+    } else {
+      is_layer_norm_lstm = true;
+    }
+  } else if (node->inputs->size == 20) {
+    // This is deprecated and is only kept here for backward compatibility.
+    is_layer_norm_lstm = false;
+  } else {
+    context->ReportError(
+        context, "The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs",
+        node->inputs->size);
+    return kTfLiteError;
+  }
   TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+  op_data->is_layer_norm_lstm = is_layer_norm_lstm;
 
   // Infer the batch size, number of outputs, sequence length, and number of
   // cells from the input tensors.
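To make the 20-vs-24 check above concrete, here is the full input layout it implies, reconstructed from the constants in this file and the ordering of input_shapes in the test further below; the enumerator names are illustrative, not the kernel's own:

// Hypothetical reference map of the 24 input tensor indices.
enum LstmFullInput {
  kInput = 0,              // {max_time, n_batch, n_input}
  kInputToInputW = 1,      // 1-4: input weights (input/forget/cell/output)
  kInputToForgetW = 2,
  kInputToCellW = 3,
  kInputToOutputW = 4,
  kRecurrentToInputW = 5,  // 5-8: recurrent weights
  kRecurrentToForgetW = 6,
  kRecurrentToCellW = 7,
  kRecurrentToOutputW = 8,
  kCellToInputW = 9,       // 9-11: peephole weights
  kCellToForgetW = 10,
  kCellToOutputW = 11,
  kInputGateBias = 12,     // 12-15: gate biases
  kForgetGateBias = 13,
  kCellBias = 14,
  kOutputGateBias = 15,
  kProjectionW = 16,       // 16-17: projection
  kProjectionBias = 17,
  kActivationState = 18,   // 18-19: variable state tensors
  kCellState = 19,
  kInputLayerNorm = 20,    // 20-23: the new, optional layer norm coefficients
  kForgetLayerNorm = 21,
  kCellLayerNorm = 22,
  kOutputLayerNorm = 23,
};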
@@ -281,8 +358,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   const int n_output = recurrent_to_output_weights->dims->data[1];
 
   // Check that the input tensor dimensions match each other.
-  TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
-                                                        n_output, n_cell));
+  TF_LITE_ENSURE_OK(context,
+                    CheckInputTensorDimensions(context, node, n_input, n_output,
+                                               n_cell, is_layer_norm_lstm));
 
   // Get the pointer to output, activation_state and cell_state buffer tensors.
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
@@ -310,7 +388,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   } else {
     node->temporaries = TfLiteIntArrayCreate(1);
   }
-  node->temporaries->data[0] = *scratch_tensor_index;
+  node->temporaries->data[0] = scratch_tensor_index;
 
   // Create a scratch buffer tensor.
   TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer);
@@ -336,7 +414,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     // Allocate temporary tensors to store quantized values of input,
     // activation_state and cell_state tensors.
     node->temporaries->data[kInputQuantized] =
-        *scratch_tensor_index + kInputQuantized;
+        scratch_tensor_index + kInputQuantized;
     TfLiteTensor* input_quantized =
         GetTemporary(context, node, kInputQuantized);
     input_quantized->type = input_to_output_weights->type;
@@ -347,7 +425,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                                        input_quantized_size));
     }
     node->temporaries->data[kOutputStateQuantized] =
-        *scratch_tensor_index + kOutputStateQuantized;
+        scratch_tensor_index + kOutputStateQuantized;
     TfLiteTensor* activation_state_quantized =
         GetTemporary(context, node, kOutputStateQuantized);
     activation_state_quantized->type = input_to_output_weights->type;
@@ -361,7 +439,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                          activation_state_quantized_size));
     }
     node->temporaries->data[kCellStateQuantized] =
-        *scratch_tensor_index + kCellStateQuantized;
+        scratch_tensor_index + kCellStateQuantized;
     TfLiteTensor* cell_state_quantized =
         GetTemporary(context, node, kCellStateQuantized);
     cell_state_quantized->type = input_to_output_weights->type;
@@ -380,7 +458,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     // different matrices (which requires multiplying the scaling factors with
     // the scaling factor of the matrix).
     node->temporaries->data[kScalingFactors] =
-        *scratch_tensor_index + kScalingFactors;
+        scratch_tensor_index + kScalingFactors;
     TfLiteTensor* scaling_factors =
         GetTemporary(context, node, kScalingFactors);
     scaling_factors->type = kTfLiteFloat32;
@@ -393,7 +471,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                                        scaling_factors_size));
     }
     node->temporaries->data[kProductScalingFactors] =
-        *scratch_tensor_index + kProductScalingFactors;
+        scratch_tensor_index + kProductScalingFactors;
     TfLiteTensor* prod_scaling_factors =
         GetTemporary(context, node, kProductScalingFactors);
     prod_scaling_factors->type = kTfLiteFloat32;
@@ -410,7 +488,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     // Allocate a temporary tensor to store the recovered cell weights. Since
     // this is used for diagonal matrices, only need to store n_cell values.
     node->temporaries->data[kRecoveredCellWeights] =
-        *scratch_tensor_index + kRecoveredCellWeights;
+        scratch_tensor_index + kRecoveredCellWeights;
     TfLiteTensor* recovered_cell_weights =
         GetTemporary(context, node, kRecoveredCellWeights);
     recovered_cell_weights->type = kTfLiteFloat32;
@@ -432,6 +510,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const auto* params =
       reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
           node->builtin_data);
+  const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+  const bool is_layer_norm_lstm = op_data->is_layer_norm_lstm;
   const bool time_major = params->time_major;
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
 
@@ -481,6 +561,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* cell_state =
       GetVariableInput(context, node, kInputCellStateTensor);
 
+  const TfLiteTensor* input_layer_norm_coefficients =
+      is_layer_norm_lstm ? GetOptionalInputTensor(
+                               context, node, kInputLayerNormCoefficientsTensor)
+                         : nullptr;
+  const TfLiteTensor* forget_layer_norm_coefficients =
+      is_layer_norm_lstm
+          ? GetInput(context, node, kForgetLayerNormCoefficientsTensor)
+          : nullptr;
+  const TfLiteTensor* cell_layer_norm_coefficients =
+      is_layer_norm_lstm
+          ? GetInput(context, node, kCellLayerNormCoefficientsTensor)
+          : nullptr;
+  const TfLiteTensor* output_layer_norm_coefficients =
+      is_layer_norm_lstm
+          ? GetInput(context, node, kOutputLayerNormCoefficientsTensor)
+          : nullptr;
+
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
 
   // Copy out the LSTM specific params so they can be passed in the function.
@@ -497,10 +594,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
         recurrent_to_input_weights, recurrent_to_forget_weights,
         recurrent_to_cell_weights, recurrent_to_output_weights,
         cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
-        /*input_layer_norm_coefficients=*/nullptr,
-        /*forget_layer_norm_coefficients=*/nullptr,
-        /*cell_layer_norm_coefficients=*/nullptr,
-        /*output_layer_norm_coefficients=*/nullptr,
+        input_layer_norm_coefficients, forget_layer_norm_coefficients,
+        cell_layer_norm_coefficients, output_layer_norm_coefficients,
         /*aux_input=*/nullptr,
         /*aux_input_to_input_weights=*/nullptr,
         /*aux_input_to_forget_weights=*/nullptr,
@@ -529,10 +624,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
         recurrent_to_input_weights, recurrent_to_forget_weights,
         recurrent_to_cell_weights, recurrent_to_output_weights,
         cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
-        /*input_layer_norm_coefficients=*/nullptr,
-        /*forget_layer_norm_coefficients=*/nullptr,
-        /*cell_layer_norm_coefficients=*/nullptr,
-        /*output_layer_norm_coefficients=*/nullptr,
+        input_layer_norm_coefficients, forget_layer_norm_coefficients,
+        cell_layer_norm_coefficients, output_layer_norm_coefficients,
         /*aux_input=*/nullptr,
         /*aux_input_to_input_weights=*/nullptr,
         /*aux_input_to_forget_weights=*/nullptr,
@@ -37,7 +37,8 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
                             bool use_projection_bias, float cell_clip,
                             float proj_clip,
                             const std::vector<std::vector<int>>& input_shapes,
-                            const TensorType& weights_type = TensorType_FLOAT32)
+                            const TensorType& weights_type = TensorType_FLOAT32,
+                            bool is_layer_norm = false)
       : n_batch_(n_batch),
         n_input_(n_input),
         n_cell_(n_cell),
@@ -108,6 +109,22 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
         AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}},
                  /*is_variable=*/true);
 
+    // Layer norm weights.
+    if (is_layer_norm) {
+      if (use_cifg) {
+        input_layer_norm_coefficients_ = AddNullInput();
+      } else {
+        input_layer_norm_coefficients_ =
+            AddLayerNormCoeffsTensor(20, input_shapes);
+      }
+      forget_layer_norm_coefficients_ =
+          AddLayerNormCoeffsTensor(21, input_shapes);
+      cell_layer_norm_coefficients_ =
+          AddLayerNormCoeffsTensor(22, input_shapes);
+      output_layer_norm_coefficients_ =
+          AddLayerNormCoeffsTensor(23, input_shapes);
+    }
+
     output_ = AddOutput(TensorType_FLOAT32);
 
     SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
@@ -187,6 +204,22 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
     PopulateTensor(projection_bias_, f);
   }
 
+  void SetInputLayerNormCoefficients(std::vector<float> f) {
+    PopulateTensor(input_layer_norm_coefficients_, f);
+  }
+
+  void SetForgetLayerNormCoefficients(std::vector<float> f) {
+    PopulateTensor(forget_layer_norm_coefficients_, f);
+  }
+
+  void SetCellLayerNormCoefficients(std::vector<float> f) {
+    PopulateTensor(cell_layer_norm_coefficients_, f);
+  }
+
+  void SetOutputLayerNormCoefficients(std::vector<float> f) {
+    PopulateTensor(output_layer_norm_coefficients_, f);
+  }
+
   void SetInput(int offset, const float* begin, const float* end) {
     PopulateTensor(input_, offset, const_cast<float*>(begin),
                    const_cast<float*>(end));
@@ -227,6 +260,11 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
   int input_activation_state_;
   int input_cell_state_;
 
+  int input_layer_norm_coefficients_;
+  int forget_layer_norm_coefficients_;
+  int cell_layer_norm_coefficients_;
+  int output_layer_norm_coefficients_;
+
   int output_;
 
   int n_batch_;
@@ -234,6 +272,16 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
   int n_cell_;
   int n_output_;
   int sequence_length_;
+
+ private:
+  int AddLayerNormCoeffsTensor(
+      int tensor_index, const std::vector<std::vector<int>>& input_shapes) {
+    if (input_shapes[tensor_index][0] != 0) {
+      return AddInput(TensorType_FLOAT32);
+    } else {
+      return AddNullInput();
+    }
+  }
 };
 
 // The hybrid model has quantized weights.
@@ -2403,6 +2451,281 @@ TEST_F(NoCifgPeepholeProjectionAndBiasClippingLstmTest, LstmBlackBoxTest) {
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
 }
 
+class LayerNormUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
+ public:
+  LayerNormUnidirectionalLSTMOpModel(
+      int n_batch, int n_input, int n_cell, int n_output, int sequence_length,
+      bool time_major, bool use_cifg, bool use_peephole,
+      bool use_projection_weights, bool use_projection_bias, float cell_clip,
+      float proj_clip, const std::vector<std::vector<int>>& input_shapes,
+      const TensorType& weights_type = TensorType_FLOAT32)
+      : UnidirectionalLSTMOpModel(
+            n_batch, n_input, n_cell, n_output, sequence_length, time_major,
+            use_cifg, use_peephole, use_projection_weights, use_projection_bias,
+            cell_clip, proj_clip, input_shapes, TensorType_FLOAT32, true) {}
+};
+
+class BaseLayerNormLstmTest : public ::testing::Test {
+ protected:
+  // Weights of the LSTM model. Some are optional.
+  std::vector<float> input_to_input_weights_;
+  std::vector<float> input_to_cell_weights_;
+  std::vector<float> input_to_forget_weights_;
+  std::vector<float> input_to_output_weights_;
+  std::vector<float> input_gate_bias_;
+  std::vector<float> cell_gate_bias_;
+  std::vector<float> forget_gate_bias_;
+  std::vector<float> output_gate_bias_;
+  std::vector<float> recurrent_to_input_weights_;
+  std::vector<float> recurrent_to_cell_weights_;
+  std::vector<float> recurrent_to_forget_weights_;
+  std::vector<float> recurrent_to_output_weights_;
+  std::vector<float> cell_to_input_weights_;
+  std::vector<float> cell_to_forget_weights_;
+  std::vector<float> cell_to_output_weights_;
+  std::vector<float> projection_weights_;
+  std::vector<float> projection_bias_;
+  std::vector<float> input_layer_norm_coefficients_;
+  std::vector<float> forget_layer_norm_coefficients_;
+  std::vector<float> cell_layer_norm_coefficients_;
+  std::vector<float> output_layer_norm_coefficients_;
+
+  // LSTM input is stored as num_batch x num_inputs vector.
+  std::vector<std::vector<float>> lstm_input_;
+  // LSTM output is stored as num_batch x num_outputs vector.
+  std::vector<std::vector<float>> lstm_golden_output_;
+
+  // Compares output up to tolerance to the result of the lstm given the input.
+  void VerifyGoldens(const std::vector<std::vector<float>>& input,
+                     const std::vector<std::vector<float>>& output,
+                     UnidirectionalLSTMOpModel* lstm, float tolerance = 1e-5) {
+    const int num_batches = input.size();
+    EXPECT_GT(num_batches, 0);
+    const int num_inputs = lstm->num_inputs();
+    EXPECT_GT(num_inputs, 0);
+    const int input_sequence_size = input[0].size() / num_inputs;
+    EXPECT_GT(input_sequence_size, 0);
+    // Feed the whole sequence as input.
+    for (int i = 0; i < input_sequence_size; ++i) {
+      for (int b = 0; b < num_batches; ++b) {
+        const float* batch_start = input[b].data() + i * num_inputs;
+        const float* batch_end = batch_start + num_inputs;
+
+        lstm->SetInput(((i * num_batches) + b) * num_inputs, batch_start,
+                       batch_end);
+      }
+    }
+
+    lstm->Invoke();
+
+    const int num_outputs = lstm->num_outputs();
+    EXPECT_GT(num_outputs, 0);
+    std::vector<float> expected;
+
+    for (int i = 0; i < input_sequence_size; ++i) {
+      for (int b = 0; b < num_batches; ++b) {
+        const float* golden_start_batch = output[b].data() + i * num_outputs;
+        const float* golden_end_batch = golden_start_batch + num_outputs;
+
+        expected.insert(expected.end(), golden_start_batch, golden_end_batch);
+      }
+    }
+    EXPECT_THAT(lstm->GetOutput(),
+                ElementsAreArray(ArrayFloatNear(expected, tolerance)));
+  }
+};
+
+class CifgPeepholeNoProjectionNoClippingLayerNormLstmTest
+    : public BaseLayerNormLstmTest {
+  void SetUp() override {
+    input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726,
+                              0.05100781,  0.04717243,  0.48944736,
+                              -0.38535351, -0.17212132};
+
+    input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988,
+                                -0.3633365,  -0.22755712, 0.28253698,
+                                0.24407166,  0.33826375};
+
+    input_to_output_weights_ = {0.10725588,  -0.02335852, -0.55932593,
+                                -0.09426838, -0.44257352, 0.54939759,
+                                0.01533556,  0.42751634};
+    cell_gate_bias_ = {0., 0., 0., 0.};
+    forget_gate_bias_ = {1., 1., 1., 1.};
+    output_gate_bias_ = {0., 0., 0., 0.};
+
+    recurrent_to_cell_weights_ = {
+        0.54066205,  -0.32668582, -0.43562764, -0.56094903,
+        0.42957711,  0.01841056,  -0.32764608, -0.33027974,
+        -0.10826075, 0.20675004,  0.19069612,  -0.03026325,
+        -0.54532051, 0.33003211,  0.44901288,  0.21193194};
+
+    recurrent_to_forget_weights_ = {
+        -0.13832897, -0.0515101,  -0.2359007,  -0.16661474,
+        -0.14340827, 0.36986142,  0.23414481,  0.55899,
+        0.10798943,  -0.41174671, 0.17751795,  -0.34484994,
+        -0.35874045, -0.11352962, 0.27268326,  0.54058349};
+
+    recurrent_to_output_weights_ = {
+        0.41613156, 0.42610586,  -0.16495961, -0.5663873,
+        0.30579174, -0.05115908, -0.33941799, 0.23364776,
+        0.11178309, 0.09481031,  -0.26424935, 0.46261835,
+        0.50248802, 0.26114327,  -0.43736315, 0.33149987};
+
+    cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408,
+                               0.31544167};
+    cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703,
+                               -0.77109635};
+
+    input_layer_norm_coefficients_ = {0.1, 0.2, 0.3, 0.5};
+    forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3};
+    cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8};
+    output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5};
+
+    lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
+    lstm_golden_output_ = {{-0.102089, 0.00653987, 0.0515139, -0.0630045,
+                            -0.173317, 0.0109206,  0.0903292, -0.109497,
+                            -0.23827,  0.0119514,  0.119525,  -0.12748}};
+  }
+};
+
+TEST_F(CifgPeepholeNoProjectionNoClippingLayerNormLstmTest,
+       LayerNormLstmBlackBoxTest) {
+  const int n_batch = 1;
+  const int n_input = 2;
+  // n_cell and n_output have the same size when there is no projection.
+  const int n_cell = 4;
+  const int n_output = 4;
+  const int sequence_length = 3;
+
+  LayerNormUnidirectionalLSTMOpModel lstm(
+      n_batch, n_input, n_cell, n_output, sequence_length,
+      /*time_major=*/true, /*use_cifg=*/true, /*use_peephole=*/true,
+      /*use_projection_weights=*/false,
+      /*use_projection_bias=*/false,
+      /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+      {
+          {sequence_length, n_batch, n_input},  // input tensor
+
+          {0, 0},             // input_to_input_weight tensor
+          {n_cell, n_input},  // input_to_forget_weight tensor
+          {n_cell, n_input},  // input_to_cell_weight tensor
+          {n_cell, n_input},  // input_to_output_weight tensor
+
+          {0, 0},              // recurrent_to_input_weight tensor
+          {n_cell, n_output},  // recurrent_to_forget_weight tensor
+          {n_cell, n_output},  // recurrent_to_cell_weight tensor
+          {n_cell, n_output},  // recurrent_to_output_weight tensor
+
+          {0},       // cell_to_input_weight tensor
+          {n_cell},  // cell_to_forget_weight tensor
+          {n_cell},  // cell_to_output_weight tensor
+
+          {0},       // input_gate_bias tensor
+          {n_cell},  // forget_gate_bias tensor
+          {n_cell},  // cell_bias tensor
+          {n_cell},  // output_gate_bias tensor
+
+          {0, 0},  // projection_weight tensor
+          {0},     // projection_bias tensor
+
+          {n_batch, n_output},  // activation_state tensor
+          {n_batch, n_cell},    // cell_state tensor
+
+          {0},       // input_layer_norm_coefficient tensor
+          {n_cell},  // forget_layer_norm_coefficient tensor
+          {n_cell},  // cell_layer_norm_coefficient tensor
+          {n_cell},  // output_layer_norm_coefficient tensor
+      });
+
+  lstm.SetInputToCellWeights(input_to_cell_weights_);
+  lstm.SetInputToForgetWeights(input_to_forget_weights_);
+  lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+  lstm.SetCellBias(cell_gate_bias_);
+  lstm.SetForgetGateBias(forget_gate_bias_);
+  lstm.SetOutputGateBias(output_gate_bias_);
+
+  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+  lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+  lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+  lstm.SetForgetLayerNormCoefficients(forget_layer_norm_coefficients_);
+  lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
+  lstm.SetOutputLayerNormCoefficients(output_layer_norm_coefficients_);
+
+  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
+}
+
+TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest,
+       NonLayerNormLstmBlackBoxTest) {
+  const int n_batch = 1;
+  const int n_input = 2;
+  // n_cell and n_output have the same size when there is no projection.
+  const int n_cell = 4;
+  const int n_output = 4;
+  const int sequence_length = 3;
+
+  LayerNormUnidirectionalLSTMOpModel lstm(
+      n_batch, n_input, n_cell, n_output, sequence_length,
+      /*time_major=*/true, /*use_cifg=*/true, /*use_peephole=*/true,
+      /*use_projection_weights=*/false,
+      /*use_projection_bias=*/false,
+      /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+      {
+          {sequence_length, n_batch, n_input},  // input tensor
+
+          {0, 0},             // input_to_input_weight tensor
+          {n_cell, n_input},  // input_to_forget_weight tensor
+          {n_cell, n_input},  // input_to_cell_weight tensor
+          {n_cell, n_input},  // input_to_output_weight tensor
+
+          {0, 0},              // recurrent_to_input_weight tensor
+          {n_cell, n_output},  // recurrent_to_forget_weight tensor
+          {n_cell, n_output},  // recurrent_to_cell_weight tensor
+          {n_cell, n_output},  // recurrent_to_output_weight tensor
+
+          {0},       // cell_to_input_weight tensor
+          {n_cell},  // cell_to_forget_weight tensor
+          {n_cell},  // cell_to_output_weight tensor
+
+          {0},       // input_gate_bias tensor
+          {n_cell},  // forget_gate_bias tensor
+          {n_cell},  // cell_bias tensor
+          {n_cell},  // output_gate_bias tensor
+
+          {0, 0},  // projection_weight tensor
+          {0},     // projection_bias tensor
+
+          {n_batch, n_output},  // activation_state tensor
+          {n_batch, n_cell},    // cell_state tensor
+
+          {0},  // input_layer_norm_coefficient tensor
+          {0},  // forget_layer_norm_coefficient tensor
+          {0},  // cell_layer_norm_coefficient tensor
+          {0},  // output_layer_norm_coefficient tensor
+      });
+
+  lstm.SetInputToCellWeights(input_to_cell_weights_);
+  lstm.SetInputToForgetWeights(input_to_forget_weights_);
+  lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+  lstm.SetCellBias(cell_gate_bias_);
+  lstm.SetForgetGateBias(forget_gate_bias_);
+  lstm.SetOutputGateBias(output_gate_bias_);
+
+  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+  lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+  lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
+}
+
 }  // namespace
 }  // namespace tflite