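Fix TFLiteLSTMCell forget_bias handling: initialize the forget-gate bias variable with forget_bias instead of adding forget_bias inside the cell-state update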
PiperOrigin-RevId: 236978279
This commit is contained in:
parent
a808d13bb9
commit
dc3900a969
@ -319,6 +319,8 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
|
||||
else:
|
||||
bias_initializer = init_ops.zeros_initializer(dtype=self.dtype)
|
||||
|
||||
+ forget_bias_initializer = init_ops.constant_initializer(self._forget_bias)
|
||||
|
||||
self.input_to_input_w = add_variable_wrapped(
|
||||
"input_to_input_w", input_weight_shape, weight_initializer, 1,
|
||||
maybe_partitioner)
|
||||
@ -346,8 +348,9 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
|
||||
|
||||
self.input_bias = add_variable_wrapped(
|
||||
"input_bias", bias_shape, bias_initializer, 12, maybe_partitioner)
|
||||
- self.forget_bias = add_variable_wrapped(
|
||||
- "forget_bias", bias_shape, bias_initializer, 13, maybe_partitioner)
|
||||
+ self.forget_bias = add_variable_wrapped("forget_bias", bias_shape,
|
||||
+ forget_bias_initializer, 13,
|
||||
+ maybe_partitioner)
|
||||
self.cell_bias = add_variable_wrapped(
|
||||
"cell_bias", bias_shape, bias_initializer, 14, maybe_partitioner)
|
||||
self.output_bias = add_variable_wrapped(
|
||||
@ -473,12 +476,10 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
|
||||
# Diagonal connections
|
||||
if self._use_peepholes:
|
||||
c = (
|
||||
- sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
|
||||
+ sigmoid(f + self._w_f_diag * c_prev) * c_prev +
|
||||
sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
|
||||
else:
|
||||
- c = (
|
||||
- sigmoid(f + self._forget_bias) * c_prev +
|
||||
- sigmoid(i) * self._activation(j))
|
||||
+ c = (sigmoid(f) * c_prev + sigmoid(i) * self._activation(j))
|
||||
|
||||
if self._cell_clip is not None:
|
||||
# pylint: disable=invalid-unary-operand-type
|
||||
|
@ -56,9 +56,9 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
|
||||
def buildLstmLayer(self):
|
||||
return tf.keras.layers.StackedRNNCells([
|
||||
tf.lite.experimental.nn.TFLiteLSTMCell(
|
||||
- self.num_units, use_peepholes=True, forget_bias=0, name="rnn1"),
|
||||
+ self.num_units, use_peepholes=True, forget_bias=1.0, name="rnn1"),
|
||||
tf.lite.experimental.nn.TFLiteLSTMCell(
|
||||
- self.num_units, num_proj=8, forget_bias=0, name="rnn2"),
|
||||
+ self.num_units, num_proj=8, forget_bias=1.0, name="rnn2"),
|
||||
tf.lite.experimental.nn.TFLiteLSTMCell(
|
||||
self.num_units // 2,
|
||||
use_peepholes=True,
|
||||
@ -66,7 +66,7 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
|
||||
forget_bias=0,
|
||||
name="rnn3"),
|
||||
tf.lite.experimental.nn.TFLiteLSTMCell(
|
||||
- self.num_units, forget_bias=0, name="rnn4")
|
||||
+ self.num_units, forget_bias=1.0, name="rnn4")
|
||||
])
|
||||
|
||||
def buildModel(self, lstm_layer, is_dynamic_rnn):
|
||||
|
Loading…
x
Reference in New Issue
Block a user