From 1ed691459980004f0f9e53bd19d751fa5449d79f Mon Sep 17 00:00:00 2001
From: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
Date: Thu, 4 May 2017 09:44:57 -0700
Subject: [PATCH] Group LSTM cell (#9606)

* GLSTM cell from https://openreview.net/forum?id=ByxWXyNFg&noteId=ByxWXyNFg

* Responding to comments on PR#9606

* Update comments according to review.

* More fixes on users' behalf.
---
 .../rnn/python/kernel_tests/rnn_cell_test.py  |  58 ++++++
 tensorflow/contrib/rnn/python/ops/rnn_cell.py | 178 ++++++++++++++++++
 2 files changed, 236 insertions(+)

diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 55fd7e7a51b..31ac89b4a84 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -906,6 +906,64 @@ class RNNCellTest(test.TestCase):
       # States are left untouched
       self.assertAllClose(res[2], res[3])
 
+  def testGLSTMCell(self):
+    # Ensure that G-LSTM matches LSTM when number_of_groups = 1
+    batch_size = 2
+    num_units = 4
+    number_of_groups = 1
+
+    with self.test_session() as sess:
+      with variable_scope.variable_scope(
+          "root1", initializer=init_ops.constant_initializer(0.5)):
+        x = array_ops.ones([batch_size, num_units])
+        # When number_of_groups = 1, G-LSTM is equivalent to regular LSTM
+        gcell = rnn_cell.GLSTMCell(num_units=num_units,
+                                   number_of_groups=number_of_groups)
+        cell = core_rnn_cell_impl.LSTMCell(num_units=num_units)
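+        # The shared constant initializer makes both cells' weights
+        # identical, so their outputs must match.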
+        self.assertTrue(isinstance(gcell.state_size, tuple))
+        zero_state = gcell.zero_state(batch_size=batch_size,
+                                      dtype=dtypes.float32)
+        gh, gs = gcell(x, zero_state)
+        h, g = cell(x, zero_state)
+
+        sess.run([variables.global_variables_initializer()])
+        glstm_result = sess.run([gh, gs])
+        lstm_result = sess.run([h, g])
+
+        self.assertAllClose(glstm_result[0], lstm_result[0], 1e-5)
+        self.assertAllClose(glstm_result[1], lstm_result[1], 1e-5)
+
+    # Test that G-LSTM subgroups act like the corresponding sub-LSTMs.
+    batch_size = 2
+    num_units = 4
+    number_of_groups = 2
+
+    with self.test_session() as sess:
+      with variable_scope.variable_scope(
+          "root2", initializer=init_ops.constant_initializer(0.5)):
+        # input for G-LSTM with 2 groups
+        glstm_input = array_ops.ones([batch_size, num_units])
+        gcell = rnn_cell.GLSTMCell(num_units=num_units,
+                                   number_of_groups=number_of_groups)
+        gcell_zero_state = gcell.zero_state(batch_size=batch_size,
+                                            dtype=dtypes.float32)
+        gh, gs = gcell(glstm_input, gcell_zero_state)
+
+        # Input for an LSTM cell simulating a single G-LSTM group; note
+        # the division by number_of_groups.
+        lstm_input = array_ops.ones(
+            [batch_size, num_units // number_of_groups])
+        cell = core_rnn_cell_impl.LSTMCell(
+            num_units=num_units // number_of_groups)
+        cell_zero_state = cell.zero_state(batch_size=batch_size,
+                                          dtype=dtypes.float32)
+        h, g = cell(lstm_input, cell_zero_state)
+
+        sess.run([variables.global_variables_initializer()])
+        [gh_res, h_res] = sess.run([gh, h])
+        self.assertAllClose(gh_res[:, :num_units // number_of_groups],
+                            h_res, 1e-5)
+        self.assertAllClose(gh_res[:, num_units // number_of_groups:],
+                            h_res, 1e-5)
 
 class LayerNormBasicLSTMCellTest(test.TestCase):
 
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index acba77f0e13..df36dd2bf9b 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -1926,3 +1926,181 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
     new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
 
     return new_h, new_state
+
+
+class GLSTMCell(core_rnn_cell.RNNCell):
+  """Group LSTM cell (G-LSTM).
+
+  The implementation is based on:
+
+    https://arxiv.org/abs/1703.10722
+
+  O. Kuchaiev and B. Ginsburg
+  "Factorization Tricks for LSTM Networks", ICLR 2017 workshop.
+  """
+
+  def __init__(self, num_units, initializer=None, num_proj=None,
+               number_of_groups=1, forget_bias=1.0, activation=math_ops.tanh,
+               reuse=None):
+    """Initialize the parameters of G-LSTM cell.
+
+    Args:
+      num_units: int, The number of units in the G-LSTM cell
+      initializer: (optional) The initializer to use for the weight and
+        projection matrices.
+      num_proj: (optional) int, The output dimensionality for the projection
+        matrices.  If None, no projection is performed.
+      number_of_groups: (optional) int, number of groups to use.
+        If `number_of_groups` is 1, the cell is equivalent to a standard
+        LSTM cell.
+      forget_bias: Biases of the forget gate are initialized by default to 1
+        in order to reduce the scale of forgetting at the beginning of
+        the training.
+      activation: Activation function of the inner states.
+      reuse: (optional) Python boolean describing whether to reuse variables
+        in an existing scope.  If not `True`, and the existing scope already
+        has the given variables, an error is raised.
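+
+    For example (an illustrative sizing, not prescribed by this patch),
+    `GLSTMCell(num_units=1024, num_proj=512, number_of_groups=4)` builds
+    four groups: each group owns a 256-unit slice of the cell state and
+    reads a 128-column slice of the input and of the projected state.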
+
+    Raises:
+      ValueError: If `num_units` or `num_proj` is not divisible by
+        `number_of_groups`.
+    """
+    super(GLSTMCell, self).__init__(_reuse=reuse)
+    self._num_units = num_units
+    self._initializer = initializer
+    self._num_proj = num_proj
+    self._forget_bias = forget_bias
+    self._activation = activation
+    self._number_of_groups = number_of_groups
+
+    if self._num_units % self._number_of_groups != 0:
+      raise ValueError("num_units must be divisible by number_of_groups")
+    if self._num_proj:
+      if self._num_proj % self._number_of_groups != 0:
+        raise ValueError("num_proj must be divisible by number_of_groups")
+      self._group_shape = [self._num_proj // self._number_of_groups,
+                           self._num_units // self._number_of_groups]
+    else:
+      self._group_shape = [self._num_units // self._number_of_groups,
+                           self._num_units // self._number_of_groups]
+
+    if num_proj:
+      self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_proj)
+      self._output_size = num_proj
+    else:
+      self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_units)
+      self._output_size = num_units
+
+  @property
+  def state_size(self):
+    return self._state_size
+
+  @property
+  def output_size(self):
+    return self._output_size
+
+  def _get_input_for_group(self, inputs, group_id, group_size):
+    """Slices inputs into groups to prepare for processing by cell's groups
+
+    Args:
+      inputs: cell input or it's previous state,
+              a Tensor, 2D, [batch x num_units]
+      group_id: group id, a Scalar, for which to prepare input
+      group_size: size of the group
+
+    Returns:
+      subset of inputs corresponding to group "group_id",
+      a Tensor, 2D, [batch x num_units/number_of_groups]
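+
+      For example, with `num_units=4` and `number_of_groups=2`, group 0
+      receives input columns [0, 2) and group 1 columns [2, 4).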
+    """
+    batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
+    return array_ops.slice(input_=inputs,
+                           begin=[0, group_id * group_size],
+                           size=[batch_size, group_size],
+                           name=("GLSTM_group%d_input_generation" % group_id))
+
+  def call(self, inputs, state):
+    """Run one step of G-LSTM.
+
+    Args:
+      inputs: input Tensor, 2D, [batch x num_units].
+      state: this must be a tuple of state Tensors, both `2-D`,
+        with column sizes `c_state` and `m_state`.
+
+    Returns:
+      A tuple containing:
+
+      - A `2-D`, `[batch x output_dim]` Tensor representing the output of
+        the G-LSTM after reading `inputs` when previous state was `state`.
+        Here output_dim is num_proj if num_proj was set, and num_units
+        otherwise.
+      - An `LSTMStateTuple` representing the new state of the G-LSTM cell
+        after reading `inputs` when the previous state was `state`.
+
+    Raises:
+      ValueError: If input size cannot be inferred from inputs via
+        static shape inference.
+    """
+    (c_prev, m_prev) = state
+
+    input_size = inputs.get_shape().with_rank(2)[1]
+    if input_size.value is None:
+      raise ValueError("Couldn't infer input size from inputs.get_shape()[-1]")
+    dtype = inputs.dtype
+    scope = vs.get_variable_scope()
+    with vs.variable_scope(scope, initializer=self._initializer):
+      i_parts = []
+      j_parts = []
+      f_parts = []
+      o_parts = []
+
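+      # Each group computes its own 4 * group_size slice of the i, j, f, o
+      # pre-activations from the matching column slices of the input and
+      # of the previous (possibly projected) state m_prev.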
+      for group_id in range(self._number_of_groups):
+        with vs.variable_scope("group%d" % group_id):
+          x_g_id = array_ops.concat(
+              [self._get_input_for_group(inputs, group_id,
+                                         self._group_shape[0]),
+               self._get_input_for_group(m_prev, group_id,
+                                         self._group_shape[0])], axis=1)
+          R_k = _linear(x_g_id, 4 * self._group_shape[1], bias=False)
+          i_k, j_k, f_k, o_k = array_ops.split(R_k, 4, 1)
+
+        i_parts.append(i_k)
+        j_parts.append(j_k)
+        f_parts.append(f_k)
+        o_parts.append(o_k)
+
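+      # The per-gate biases span the full num_units width and are added
+      # after the group slices are concatenated back together.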
+      bi = vs.get_variable(
+          name="bias_i", shape=[self._num_units], dtype=dtype,
+          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
+      bj = vs.get_variable(
+          name="bias_j", shape=[self._num_units], dtype=dtype,
+          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
+      bf = vs.get_variable(
+          name="bias_f", shape=[self._num_units], dtype=dtype,
+          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
+      bo = vs.get_variable(
+          name="bias_o", shape=[self._num_units], dtype=dtype,
+          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
+
+      i = nn_ops.bias_add(array_ops.concat(i_parts, axis=1), bi)
+      j = nn_ops.bias_add(array_ops.concat(j_parts, axis=1), bj)
+      f = nn_ops.bias_add(array_ops.concat(f_parts, axis=1), bf)
+      o = nn_ops.bias_add(array_ops.concat(o_parts, axis=1), bo)
+
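+    # Standard LSTM state update over the full concatenated width:
+    #   c_t = sigmoid(f + forget_bias) * c_{t-1} + sigmoid(i) * tanh(j)
+    #   m_t = sigmoid(o) * activation(c_t)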
+    c = (math_ops.sigmoid(f + self._forget_bias) * c_prev +
+         math_ops.sigmoid(i) * math_ops.tanh(j))
+    m = math_ops.sigmoid(o) * self._activation(c)
+
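+    # Optionally project m down to num_proj dimensions, mirroring the
+    # projection step of the standard LSTMCell.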
+    if self._num_proj is not None:
+      with vs.variable_scope("projection"):
+        m = _linear(m, self._num_proj, bias=False)
+
+    new_state = core_rnn_cell.LSTMStateTuple(c, m)
+    return m, new_state