Fixed dimensions of activation and cell state tensors.
State tensors are {batch, size} (2D) and not {batch * size} (1D). PiperOrigin-RevId: 260040019
This commit is contained in:
parent
4ad169057c
commit
a879f9b308
@ -103,9 +103,9 @@ class LSTMOpModel : public SingleOpModel {
|
|||||||
|
|
||||||
// Adding the 2 input state tensors.
|
// Adding the 2 input state tensors.
|
||||||
input_activation_state_ =
|
input_activation_state_ =
|
||||||
AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true);
|
AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_output_}}, true);
|
||||||
input_cell_state_ =
|
input_cell_state_ =
|
||||||
AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
|
AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_cell_}}, true);
|
||||||
|
|
||||||
// Layer norm weights.
|
// Layer norm weights.
|
||||||
if (is_layer_norm) {
|
if (is_layer_norm) {
|
||||||
@ -1589,7 +1589,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, LstmBlackBoxTest) {
|
|||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTesInt8) {
|
TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTestInt8) {
|
||||||
const int n_batch = 2;
|
const int n_batch = 2;
|
||||||
const int n_input = 5;
|
const int n_input = 5;
|
||||||
const int n_cell = 20;
|
const int n_cell = 20;
|
||||||
@ -1626,7 +1626,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTesInt8) {
|
|||||||
{n_output, n_cell}, // projection_weight tensor
|
{n_output, n_cell}, // projection_weight tensor
|
||||||
{0}, // projection_bias tensor
|
{0}, // projection_bias tensor
|
||||||
},
|
},
|
||||||
TensorType_UINT8);
|
TensorType_INT8);
|
||||||
|
|
||||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||||
@ -1690,7 +1690,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest,
|
|||||||
{n_output, n_cell}, // projection_weight tensor
|
{n_output, n_cell}, // projection_weight tensor
|
||||||
{0}, // projection_bias tensor
|
{0}, // projection_bias tensor
|
||||||
},
|
},
|
||||||
TensorType_INT8);
|
TensorType_UINT8);
|
||||||
|
|
||||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||||
|
Loading…
Reference in New Issue
Block a user