Group LSTM cell (#9606)
* GLSTM cell from https://openreview.net/forum?id=ByxWXyNFg&noteId=ByxWXyNFg
* Responding to comments on PR #9606
* Update comments according to review.
* More fixes on users' behalf.
parent d0042ed637
commit 1ed6914599
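For context, a minimal usage sketch of the new cell (not part of the diff; the import path matches the test below, and the batch/unit/group numbers are arbitrary):

# Hedged sketch: one G-LSTM step via the cell added in this commit.
import tensorflow as tf
from tensorflow.contrib.rnn.python.ops import rnn_cell

x = tf.ones([8, 16])  # [batch_size, num_units]
gcell = rnn_cell.GLSTMCell(num_units=16, number_of_groups=2)
state = gcell.zero_state(batch_size=8, dtype=tf.float32)
output, new_state = gcell(x, state)  # output shape: [8, 16]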
@@ -906,6 +906,64 @@ class RNNCellTest(test.TestCase):
      # States are left untouched
      self.assertAllClose(res[2], res[3])

  def testGLSTMCell(self):
    # Ensure that G-LSTM matches LSTM when number_of_groups = 1
    batch_size = 2
    num_units = 4
    number_of_groups = 1

    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root1", initializer=init_ops.constant_initializer(0.5)):
        x = array_ops.ones([batch_size, num_units])
        # When number_of_groups = 1, G-LSTM is equivalent to regular LSTM
        gcell = rnn_cell.GLSTMCell(num_units=num_units,
                                   number_of_groups=number_of_groups)
        cell = core_rnn_cell_impl.LSTMCell(num_units=num_units)
        self.assertTrue(isinstance(gcell.state_size, tuple))
        zero_state = gcell.zero_state(batch_size=batch_size,
                                      dtype=dtypes.float32)
        gh, gs = gcell(x, zero_state)
        h, g = cell(x, zero_state)

        sess.run([variables.global_variables_initializer()])
        glstm_result = sess.run([gh, gs])
        lstm_result = sess.run([h, g])

        self.assertAllClose(glstm_result[0], lstm_result[0], 1e-5)
        self.assertAllClose(glstm_result[1], lstm_result[1], 1e-5)

    # Test that G-LSTM subgroups act like corresponding sub-LSTMs
    batch_size = 2
    num_units = 4
    number_of_groups = 2

    with self.test_session() as sess:
      with variable_scope.variable_scope(
          "root2", initializer=init_ops.constant_initializer(0.5)):
        # input for G-LSTM with 2 groups
        glstm_input = array_ops.ones([batch_size, num_units])
        gcell = rnn_cell.GLSTMCell(num_units=num_units,
                                   number_of_groups=number_of_groups)
        gcell_zero_state = gcell.zero_state(batch_size=batch_size,
                                            dtype=dtypes.float32)
        gh, gs = gcell(glstm_input, gcell_zero_state)

        # input for an LSTM cell simulating a single G-LSTM group
        lstm_input = array_ops.ones(
            [batch_size, int(num_units / number_of_groups)])
        # note the division by number_of_groups; this cell simulates one
        # G-LSTM group
        cell = core_rnn_cell_impl.LSTMCell(
            num_units=int(num_units / number_of_groups))
        cell_zero_state = cell.zero_state(batch_size=batch_size,
                                          dtype=dtypes.float32)
        h, g = cell(lstm_input, cell_zero_state)

        sess.run([variables.global_variables_initializer()])
        [gh_res, h_res] = sess.run([gh, h])
        self.assertAllClose(gh_res[:, 0:int(num_units / number_of_groups)],
                            h_res, 1e-5)
        self.assertAllClose(gh_res[:, int(num_units / number_of_groups):],
                            h_res, 1e-5)


class LayerNormBasicLSTMCellTest(test.TestCase):
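As a back-of-envelope check of the factorization benefit (illustrative only; `glstm_weight_count` is a hypothetical helper derived from the implementation below, where each group's `_linear` has a `[2 * num_units/g, 4 * num_units/g]` kernel with no bias, plus four `num_units`-sized bias vectors):

def glstm_weight_count(num_units, number_of_groups):
  # Hypothetical helper: weight count of GLSTMCell with num_proj=None.
  group_size = num_units // number_of_groups
  kernel = (2 * group_size) * (4 * group_size)  # per-group _linear kernel
  return number_of_groups * kernel + 4 * num_units  # plus bias_i/j/f/o

print(glstm_weight_count(1024, 1))  # 8392704
print(glstm_weight_count(1024, 4))  # 2101248, roughly a 4x reduction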
@@ -1926,3 +1926,181 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
    new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
    return new_h, new_state


class GLSTMCell(core_rnn_cell.RNNCell):
  """Group LSTM cell (G-LSTM).

  The implementation is based on:

    https://arxiv.org/abs/1703.10722

  O. Kuchaiev and B. Ginsburg
  "Factorization Tricks for LSTM Networks", ICLR 2017 workshop.
  """
  def __init__(self, num_units, initializer=None, num_proj=None,
               number_of_groups=1, forget_bias=1.0, activation=math_ops.tanh,
               reuse=None):
    """Initialize the parameters of G-LSTM cell.

    Args:
      num_units: int, The number of units in the G-LSTM cell.
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices. If None, no projection is performed.
      number_of_groups: (optional) int, number of groups to use.
        If `number_of_groups` is 1, the cell is equivalent to a standard
        LSTM cell.
      forget_bias: Biases of the forget gate are initialized by default to 1
        in order to reduce the scale of forgetting at the beginning of
        the training.
      activation: Activation function of the inner states.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already
        has the given variables, an error is raised.

    Raises:
      ValueError: If `num_units` or `num_proj` is not divisible by
        `number_of_groups`.
    """
    super(GLSTMCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._initializer = initializer
    self._num_proj = num_proj
    self._forget_bias = forget_bias
    self._activation = activation
    self._number_of_groups = number_of_groups

    if self._num_units % self._number_of_groups != 0:
      raise ValueError("num_units must be divisible by number_of_groups")
    if self._num_proj:
      if self._num_proj % self._number_of_groups != 0:
        raise ValueError("num_proj must be divisible by number_of_groups")
      self._group_shape = [int(self._num_proj / self._number_of_groups),
                           int(self._num_units / self._number_of_groups)]
    else:
      self._group_shape = [int(self._num_units / self._number_of_groups),
                           int(self._num_units / self._number_of_groups)]

    if num_proj:
      self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_proj)
      self._output_size = num_proj
    else:
      self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_units)
      self._output_size = num_units
  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size
  def _get_input_for_group(self, inputs, group_id, group_size):
    """Slices inputs into groups to prepare for processing by cell's groups.

    Args:
      inputs: cell input or its previous state,
        a Tensor, 2D, [batch x num_units]
      group_id: group id, a Scalar, for which to prepare input
      group_size: size of the group

    Returns:
      subset of inputs corresponding to group "group_id",
      a Tensor, 2D, [batch x num_units/number_of_groups]
    """
    batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
    return array_ops.slice(input_=inputs,
                           begin=[0, group_id * group_size],
                           size=[batch_size, group_size],
                           name=("GLSTM_group%d_input_generation" % group_id))
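  # Illustrative note (not in the original commit): with num_units=4 and
  # number_of_groups=2, group_size is 2, so group 0 is fed inputs[:, 0:2]
  # and group 1 is fed inputs[:, 2:4].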
  def call(self, inputs, state):
    """Run one step of G-LSTM.

    Args:
      inputs: input Tensor, 2D, [batch x num_units].
      state: this must be a tuple of state Tensors, both `2-D`,
        with column sizes `c_state` and `m_state`.

    Returns:
      A tuple containing:

      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        G-LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
          num_proj if num_proj was set,
          num_units otherwise.
      - LSTMStateTuple representing the new state of G-LSTM cell
        after reading `inputs` when the previous state was `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    (c_prev, m_prev) = state

    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Couldn't infer input size from inputs.get_shape()[-1]")
    dtype = inputs.dtype
    scope = vs.get_variable_scope()
    with vs.variable_scope(scope, initializer=self._initializer):
      i_parts = []
      j_parts = []
      f_parts = []
      o_parts = []

      for group_id in range(self._number_of_groups):
        with vs.variable_scope("group%d" % group_id):
          x_g_id = array_ops.concat(
              [self._get_input_for_group(inputs, group_id,
                                         self._group_shape[0]),
               self._get_input_for_group(m_prev, group_id,
                                         self._group_shape[0])], axis=1)
          R_k = _linear(x_g_id, 4 * self._group_shape[1], bias=False)
          i_k, j_k, f_k, o_k = array_ops.split(R_k, 4, 1)

        i_parts.append(i_k)
        j_parts.append(j_k)
        f_parts.append(f_k)
        o_parts.append(o_k)
bi = vs.get_variable(name="bias_i",
|
||||||
|
shape=[self._num_units],
|
||||||
|
dtype=dtype,
|
||||||
|
initializer=
|
||||||
|
init_ops.constant_initializer(0.0, dtype=dtype))
|
||||||
|
bj = vs.get_variable(name="bias_j",
|
||||||
|
shape=[self._num_units],
|
||||||
|
dtype=dtype,
|
||||||
|
initializer=
|
||||||
|
init_ops.constant_initializer(0.0, dtype=dtype))
|
||||||
|
bf = vs.get_variable(name="bias_f",
|
||||||
|
shape=[self._num_units],
|
||||||
|
dtype=dtype,
|
||||||
|
initializer=
|
||||||
|
init_ops.constant_initializer(0.0, dtype=dtype))
|
||||||
|
bo = vs.get_variable(name="bias_o",
|
||||||
|
shape=[self._num_units],
|
||||||
|
dtype=dtype,
|
||||||
|
initializer=
|
||||||
|
init_ops.constant_initializer(0.0, dtype=dtype))
|
||||||
|
|
||||||
|
      i = nn_ops.bias_add(array_ops.concat(i_parts, axis=1), bi)
      j = nn_ops.bias_add(array_ops.concat(j_parts, axis=1), bj)
      f = nn_ops.bias_add(array_ops.concat(f_parts, axis=1), bf)
      o = nn_ops.bias_add(array_ops.concat(o_parts, axis=1), bo)

      c = (math_ops.sigmoid(f + self._forget_bias) * c_prev +
           math_ops.sigmoid(i) * math_ops.tanh(j))
      m = math_ops.sigmoid(o) * self._activation(c)

      if self._num_proj is not None:
        with vs.variable_scope("projection"):
          m = _linear(m, self._num_proj, bias=False)

      new_state = core_rnn_cell.LSTMStateTuple(c, m)
      return m, new_state
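In equation form, the step above (a paraphrase of the code, with $g$ groups, $\phi$ the cell activation, $[\,;\,]$ concatenation, and $x^{(k)}$, $m^{(k)}_{t-1}$ the group-$k$ slices produced by `_get_input_for_group`; the notation is ours, following Kuchaiev and Ginsburg loosely):

$$\begin{aligned}
[i_k,\, j_k,\, f_k,\, o_k] &= W_k\,[x^{(k)};\, m^{(k)}_{t-1}], \qquad k = 1, \dots, g,\\
c_t &= \sigma(f + b_{\text{forget}}) \odot c_{t-1} + \sigma(i) \odot \tanh(j),\\
m_t &= \sigma(o) \odot \phi(c_t),
\end{aligned}$$

where $i, j, f, o$ are the bias-added concatenations of the per-group parts, and $m_t$ is optionally projected to `num_proj` dimensions.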