Fixed dimensions of activation and cell state tensors.
State tensors are {batch, size} (2D) and not {batch * size} (1D). PiperOrigin-RevId: 260040019
This commit is contained in:
parent
4ad169057c
commit
a879f9b308
@ -103,9 +103,9 @@ class LSTMOpModel : public SingleOpModel {
|
||||
|
||||
// Adding the 2 input state tensors.
|
||||
input_activation_state_ =
|
||||
AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true);
|
||||
AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_output_}}, true);
|
||||
input_cell_state_ =
|
||||
AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
|
||||
AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_cell_}}, true);
|
||||
|
||||
// Layer norm weights.
|
||||
if (is_layer_norm) {
|
||||
@ -1589,7 +1589,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, LstmBlackBoxTest) {
|
||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
||||
}
|
||||
|
||||
TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTesInt8) {
|
||||
TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTestInt8) {
|
||||
const int n_batch = 2;
|
||||
const int n_input = 5;
|
||||
const int n_cell = 20;
|
||||
@ -1626,7 +1626,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTesInt8) {
|
||||
{n_output, n_cell}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
},
|
||||
TensorType_UINT8);
|
||||
TensorType_INT8);
|
||||
|
||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
@ -1690,7 +1690,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest,
|
||||
{n_output, n_cell}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
},
|
||||
TensorType_INT8);
|
||||
TensorType_UINT8);
|
||||
|
||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
|
Loading…
Reference in New Issue
Block a user