diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 55fd7e7a51b..31ac89b4a84 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -906,6 +906,64 @@ class RNNCellTest(test.TestCase):
       # States are left untouched
       self.assertAllClose(res[2], res[3])
 
+  def testGLSTMCell(self):
+    # Ensure that G-LSTM matches LSTM when number_of_groups = 1
+    batch_size = 2
+    num_units = 4
+    number_of_groups = 1
+
+    with self.test_session() as sess:
+      with variable_scope.variable_scope(
+          "root1", initializer=init_ops.constant_initializer(0.5)):
+        x = array_ops.ones([batch_size, num_units])
+        # When number_of_groups = 1, G-LSTM is equivalent to a regular LSTM
+        gcell = rnn_cell.GLSTMCell(num_units=num_units,
+                                   number_of_groups=number_of_groups)
+        cell = core_rnn_cell_impl.LSTMCell(num_units=num_units)
+        self.assertTrue(isinstance(gcell.state_size, tuple))
+        zero_state = gcell.zero_state(batch_size=batch_size,
+                                      dtype=dtypes.float32)
+        gh, gs = gcell(x, zero_state)
+        h, g = cell(x, zero_state)
+
+        sess.run([variables.global_variables_initializer()])
+        glstm_result = sess.run([gh, gs])
+        lstm_result = sess.run([h, g])
+
+        self.assertAllClose(glstm_result[0], lstm_result[0], 1e-5)
+        self.assertAllClose(glstm_result[1], lstm_result[1], 1e-5)
+
+    # Test that G-LSTM subgroups act like the corresponding sub-LSTMs
+    batch_size = 2
+    num_units = 4
+    number_of_groups = 2
+
+    with self.test_session() as sess:
+      with variable_scope.variable_scope(
+          "root2", initializer=init_ops.constant_initializer(0.5)):
+        # input for G-LSTM with 2 groups
+        glstm_input = array_ops.ones([batch_size, num_units])
+        gcell = rnn_cell.GLSTMCell(num_units=num_units,
+                                   number_of_groups=number_of_groups)
+        gcell_zero_state = gcell.zero_state(batch_size=batch_size,
+                                            dtype=dtypes.float32)
+        gh, gs = gcell(glstm_input, gcell_zero_state)
+
+        # input for an LSTM cell simulating a single G-LSTM group;
+        # note the integer division by number_of_groups
+        lstm_input = array_ops.ones(
+            [batch_size, num_units // number_of_groups])
+        cell = core_rnn_cell_impl.LSTMCell(
+            num_units=num_units // number_of_groups)
+        cell_zero_state = cell.zero_state(batch_size=batch_size,
+                                          dtype=dtypes.float32)
+        h, g = cell(lstm_input, cell_zero_state)
+
+        sess.run([variables.global_variables_initializer()])
+        [gh_res, h_res] = sess.run([gh, h])
+        self.assertAllClose(gh_res[:, 0:num_units // number_of_groups],
+                            h_res, 1e-5)
+        self.assertAllClose(gh_res[:, num_units // number_of_groups:],
+                            h_res, 1e-5)
 
 
 class LayerNormBasicLSTMCellTest(test.TestCase):
 
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index acba77f0e13..df36dd2bf9b 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -1926,3 +1926,181 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
     new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
 
     return new_h, new_state
+
+
+class GLSTMCell(core_rnn_cell.RNNCell):
+  """Group LSTM cell (G-LSTM).
+
+  The implementation is based on:
+
+    https://arxiv.org/abs/1703.10722
+
+  O. Kuchaiev and B. Ginsburg
+  "Factorization Tricks for LSTM Networks", ICLR 2017 workshop.
+  """
+
+  def __init__(self, num_units, initializer=None, num_proj=None,
+               number_of_groups=1, forget_bias=1.0, activation=math_ops.tanh,
+               reuse=None):
+    """Initialize the parameters of the G-LSTM cell.
+
+    Args:
+      num_units: int, The number of units in the G-LSTM cell.
+      initializer: (optional) The initializer to use for the weight and
+        projection matrices.
+      num_proj: (optional) int, The output dimensionality for the projection
+        matrices. If None, no projection is performed.
+      number_of_groups: (optional) int, number of groups to use.
+        If `number_of_groups` is 1, the cell is equivalent to an LSTM cell.
+      forget_bias: Biases of the forget gate are initialized by default to 1
+        in order to reduce the scale of forgetting at the beginning of
+        the training.
+      activation: Activation function of the inner states.
+      reuse: (optional) Python boolean describing whether to reuse variables
+        in an existing scope. If not `True`, and the existing scope already
+        has the given variables, an error is raised.
+
+    Raises:
+      ValueError: If `num_units` or `num_proj` is not divisible by
+        `number_of_groups`.
+    """
+    super(GLSTMCell, self).__init__(_reuse=reuse)
+    self._num_units = num_units
+    self._initializer = initializer
+    self._num_proj = num_proj
+    self._forget_bias = forget_bias
+    self._activation = activation
+    self._number_of_groups = number_of_groups
+
+    if self._num_units % self._number_of_groups != 0:
+      raise ValueError("num_units must be divisible by number_of_groups")
+    if self._num_proj:
+      if self._num_proj % self._number_of_groups != 0:
+        raise ValueError("num_proj must be divisible by number_of_groups")
+      self._group_shape = [self._num_proj // self._number_of_groups,
+                           self._num_units // self._number_of_groups]
+    else:
+      self._group_shape = [self._num_units // self._number_of_groups,
+                           self._num_units // self._number_of_groups]
+
+    if num_proj:
+      self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_proj)
+      self._output_size = num_proj
+    else:
+      self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_units)
+      self._output_size = num_units
+
+  @property
+  def state_size(self):
+    return self._state_size
+
+  @property
+  def output_size(self):
+    return self._output_size
+
+  def _get_input_for_group(self, inputs, group_id, group_size):
+    """Slices inputs into groups to prepare for processing by the cell's groups.
+
+    Args:
+      inputs: cell input or its previous state,
+        a Tensor, 2D, [batch x num_units].
+      group_id: group id, a Scalar, for which to prepare input.
+      group_size: size of the group.
+
+    Returns:
+      subset of inputs corresponding to group "group_id",
+      a Tensor, 2D, [batch x num_units/number_of_groups].
+    """
+    # Fall back to the dynamic batch size when the static one is unknown.
+    batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
+    return array_ops.slice(input_=inputs,
+                           begin=[0, group_id * group_size],
+                           size=[batch_size, group_size],
+                           name=("GLSTM_group%d_input_generation" % group_id))
+
+  def call(self, inputs, state):
+    """Run one step of G-LSTM.
+
+    Args:
+      inputs: input Tensor, 2D, [batch x num_units].
+      state: this must be a tuple of state Tensors, both `2-D`,
+        with column sizes `c_state` and `m_state`.
+
+    Returns:
+      A tuple containing:
+
+      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
+        G-LSTM after reading `inputs` when the previous state was `state`.
+        Here output_dim is:
+          num_proj if num_proj was set,
+          num_units otherwise.
+      - LSTMStateTuple representing the new state of the G-LSTM cell
+        after reading `inputs` when the previous state was `state`.
+
+    Raises:
+      ValueError: If the input size cannot be inferred from `inputs` via
+        static shape inference.
+ """ + (c_prev, m_prev) = state + + input_size = inputs.get_shape().with_rank(2)[1] + if input_size.value is None: + raise ValueError("Couldn't infer input size from inputs.get_shape()[-1]") + dtype = inputs.dtype + scope = vs.get_variable_scope() + with vs.variable_scope(scope, initializer=self._initializer): + i_parts = [] + j_parts = [] + f_parts = [] + o_parts = [] + + for group_id in range(self._number_of_groups): + with vs.variable_scope("group%d" % group_id): + x_g_id = array_ops.concat( + [self._get_input_for_group(inputs, group_id, + self._group_shape[0]), + self._get_input_for_group(m_prev, group_id, + self._group_shape[0])], axis=1) + R_k = _linear(x_g_id, 4 * self._group_shape[1], bias=False) + i_k, j_k, f_k, o_k = array_ops.split(R_k, 4, 1) + + i_parts.append(i_k) + j_parts.append(j_k) + f_parts.append(f_k) + o_parts.append(o_k) + + bi = vs.get_variable(name="bias_i", + shape=[self._num_units], + dtype=dtype, + initializer= + init_ops.constant_initializer(0.0, dtype=dtype)) + bj = vs.get_variable(name="bias_j", + shape=[self._num_units], + dtype=dtype, + initializer= + init_ops.constant_initializer(0.0, dtype=dtype)) + bf = vs.get_variable(name="bias_f", + shape=[self._num_units], + dtype=dtype, + initializer= + init_ops.constant_initializer(0.0, dtype=dtype)) + bo = vs.get_variable(name="bias_o", + shape=[self._num_units], + dtype=dtype, + initializer= + init_ops.constant_initializer(0.0, dtype=dtype)) + + i = nn_ops.bias_add(array_ops.concat(i_parts, axis=1), bi) + j = nn_ops.bias_add(array_ops.concat(j_parts, axis=1), bj) + f = nn_ops.bias_add(array_ops.concat(f_parts, axis=1), bf) + o = nn_ops.bias_add(array_ops.concat(o_parts, axis=1), bo) + + c = (math_ops.sigmoid(f + self._forget_bias) * c_prev + + math_ops.sigmoid(i) * math_ops.tanh(j)) + m = math_ops.sigmoid(o) * self._activation(c) + + if self._num_proj is not None: + with vs.variable_scope("projection"): + m = _linear(m, self._num_proj, bias=False) + + new_state = core_rnn_cell.LSTMStateTuple(c, m) + return m, new_state