Fix lstmcell for forget_bias

PiperOrigin-RevId: 236978279
Authored by A. Unique TensorFlower on 2019-03-05 21:28:35 -08:00; committed by TensorFlower Gardener
parent a808d13bb9
commit dc3900a969
2 changed files with 10 additions and 9 deletions


@@ -319,6 +319,8 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
     else:
       bias_initializer = init_ops.zeros_initializer(dtype=self.dtype)
+    forget_bias_initializer = init_ops.constant_initializer(self._forget_bias)
+
     self.input_to_input_w = add_variable_wrapped(
         "input_to_input_w", input_weight_shape, weight_initializer, 1,
         maybe_partitioner)
@@ -346,8 +348,9 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
     self.input_bias = add_variable_wrapped(
         "input_bias", bias_shape, bias_initializer, 12, maybe_partitioner)
-    self.forget_bias = add_variable_wrapped(
-        "forget_bias", bias_shape, bias_initializer, 13, maybe_partitioner)
+    self.forget_bias = add_variable_wrapped("forget_bias", bias_shape,
+                                            forget_bias_initializer, 13,
+                                            maybe_partitioner)
     self.cell_bias = add_variable_wrapped(
         "cell_bias", bias_shape, bias_initializer, 14, maybe_partitioner)
     self.output_bias = add_variable_wrapped(
@@ -473,12 +476,10 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
     # Diagonal connections
     if self._use_peepholes:
       c = (
-          sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
+          sigmoid(f + self._w_f_diag * c_prev) * c_prev +
           sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
     else:
-      c = (
-          sigmoid(f + self._forget_bias) * c_prev +
-          sigmoid(i) * self._activation(j))
+      c = (sigmoid(f) * c_prev + sigmoid(i) * self._activation(j))
 
     if self._cell_clip is not None:
       # pylint: disable=invalid-unary-operand-type
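
Taken together, these hunks fold the scalar forget_bias into the initializer of the forget_bias variable instead of re-adding it inside the gate equations. The following plain-NumPy sketch (not the TFLite code path; num_units and the random logits are made-up placeholders) illustrates why the two formulations give the same forget gate at initialization, before the bias variable is trained:

import numpy as np

def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

num_units = 4                     # hypothetical cell size
forget_bias = 1.0                 # the cell's forget_bias argument
f = np.random.randn(3, num_units).astype(np.float32)  # forget-gate logits before any bias

# Old behaviour: the forget bias variable starts at zero and the scalar
# forget_bias is added again inside the gate math.
old_bias_var = np.zeros(num_units, dtype=np.float32)
old_gate = sigmoid(f + old_bias_var + forget_bias)

# New behaviour: the variable itself is initialized to forget_bias and the
# gate math adds nothing extra.
new_bias_var = np.full(num_units, forget_bias, dtype=np.float32)
new_gate = sigmoid(f + new_bias_var)

assert np.allclose(old_gate, new_gate)

Once training starts the two versions can diverge, but baking the value into the variable means the bias tensor that gets exported already carries the forget bias, which appears to be the point of the fix.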


@@ -56,9 +56,9 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
   def buildLstmLayer(self):
     return tf.keras.layers.StackedRNNCells([
         tf.lite.experimental.nn.TFLiteLSTMCell(
-            self.num_units, use_peepholes=True, forget_bias=0, name="rnn1"),
+            self.num_units, use_peepholes=True, forget_bias=1.0, name="rnn1"),
         tf.lite.experimental.nn.TFLiteLSTMCell(
-            self.num_units, num_proj=8, forget_bias=0, name="rnn2"),
+            self.num_units, num_proj=8, forget_bias=1.0, name="rnn2"),
         tf.lite.experimental.nn.TFLiteLSTMCell(
             self.num_units // 2,
             use_peepholes=True,
@@ -66,7 +66,7 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
             forget_bias=0,
             name="rnn3"),
         tf.lite.experimental.nn.TFLiteLSTMCell(
-            self.num_units, forget_bias=0, name="rnn4")
+            self.num_units, forget_bias=1.0, name="rnn4")
     ])
 
   def buildModel(self, lstm_layer, is_dynamic_rnn):
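
The test now exercises the conventional forget_bias=1.0 on several cells, which the cell previously did not carry through correctly. As a hedged usage sketch mirroring the updated test (TensorFlow 1.x with tf.lite.experimental.nn available; num_units is a made-up placeholder), a stacked layer can be built as:

import tensorflow as tf

num_units = 16  # hypothetical size, not taken from the test

# With the fix, forget_bias=1.0 is baked into each cell's forget bias
# variable instead of being added inside the gate equations on every call.
lstm_layer = tf.keras.layers.StackedRNNCells([
    tf.lite.experimental.nn.TFLiteLSTMCell(
        num_units, use_peepholes=True, forget_bias=1.0, name="rnn1"),
    tf.lite.experimental.nn.TFLiteLSTMCell(
        num_units, num_proj=8, forget_bias=1.0, name="rnn2"),
])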