diff --git a/tensorflow/lite/experimental/examples/lstm/rnn_cell.py b/tensorflow/lite/experimental/examples/lstm/rnn_cell.py index 44446fb509d..6fdd16d079c 100644 --- a/tensorflow/lite/experimental/examples/lstm/rnn_cell.py +++ b/tensorflow/lite/experimental/examples/lstm/rnn_cell.py @@ -319,6 +319,8 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell): else: bias_initializer = init_ops.zeros_initializer(dtype=self.dtype) + forget_bias_initializer = init_ops.constant_initializer(self._forget_bias) + self.input_to_input_w = add_variable_wrapped( "input_to_input_w", input_weight_shape, weight_initializer, 1, maybe_partitioner) @@ -346,8 +348,9 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell): self.input_bias = add_variable_wrapped( "input_bias", bias_shape, bias_initializer, 12, maybe_partitioner) - self.forget_bias = add_variable_wrapped( - "forget_bias", bias_shape, bias_initializer, 13, maybe_partitioner) + self.forget_bias = add_variable_wrapped("forget_bias", bias_shape, + forget_bias_initializer, 13, + maybe_partitioner) self.cell_bias = add_variable_wrapped( "cell_bias", bias_shape, bias_initializer, 14, maybe_partitioner) self.output_bias = add_variable_wrapped( @@ -473,12 +476,10 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell): # Diagonal connections if self._use_peepholes: c = ( - sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev + + sigmoid(f + self._w_f_diag * c_prev) * c_prev + sigmoid(i + self._w_i_diag * c_prev) * self._activation(j)) else: - c = ( - sigmoid(f + self._forget_bias) * c_prev + - sigmoid(i) * self._activation(j)) + c = (sigmoid(f) * c_prev + sigmoid(i) * self._activation(j)) if self._cell_clip is not None: # pylint: disable=invalid-unary-operand-type diff --git a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py index e40a3e5d3c1..e29c7510034 100644 --- a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py +++ b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py @@ -56,9 +56,9 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase): def buildLstmLayer(self): return tf.keras.layers.StackedRNNCells([ tf.lite.experimental.nn.TFLiteLSTMCell( - self.num_units, use_peepholes=True, forget_bias=0, name="rnn1"), + self.num_units, use_peepholes=True, forget_bias=1.0, name="rnn1"), tf.lite.experimental.nn.TFLiteLSTMCell( - self.num_units, num_proj=8, forget_bias=0, name="rnn2"), + self.num_units, num_proj=8, forget_bias=1.0, name="rnn2"), tf.lite.experimental.nn.TFLiteLSTMCell( self.num_units // 2, use_peepholes=True, @@ -66,7 +66,7 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase): forget_bias=0, name="rnn3"), tf.lite.experimental.nn.TFLiteLSTMCell( - self.num_units, forget_bias=0, name="rnn4") + self.num_units, forget_bias=1.0, name="rnn4") ]) def buildModel(self, lstm_layer, is_dynamic_rnn):