Fixed dimensions of activation and cell state tensors.

State tensors are {batch, size} (2D) and not {batch * size} (1D).

PiperOrigin-RevId: 260040019
This commit is contained in:
A. Unique TensorFlower 2019-07-25 16:05:42 -07:00 committed by TensorFlower Gardener
parent 4ad169057c
commit a879f9b308

View File

@ -103,9 +103,9 @@ class LSTMOpModel : public SingleOpModel {
// Adding the 2 input state tensors. // Adding the 2 input state tensors.
input_activation_state_ = input_activation_state_ =
AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true); AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_output_}}, true);
input_cell_state_ = input_cell_state_ =
AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true); AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_cell_}}, true);
// Layer norm weights. // Layer norm weights.
if (is_layer_norm) { if (is_layer_norm) {
@ -1589,7 +1589,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, LstmBlackBoxTest) {
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
} }
TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTesInt8) { TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTestInt8) {
const int n_batch = 2; const int n_batch = 2;
const int n_input = 5; const int n_input = 5;
const int n_cell = 20; const int n_cell = 20;
@ -1626,7 +1626,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTesInt8) {
{n_output, n_cell}, // projection_weight tensor {n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor {0}, // projection_bias tensor
}, },
TensorType_UINT8); TensorType_INT8);
lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToInputWeights(input_to_input_weights_);
lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_);
@ -1690,7 +1690,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest,
{n_output, n_cell}, // projection_weight tensor {n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor {0}, // projection_bias tensor
}, },
TensorType_INT8); TensorType_UINT8);
lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToInputWeights(input_to_input_weights_);
lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_);