# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for RNN cells."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import numpy as np
import tensorflow as tf

# TODO(ebrevdo): Remove once _linear is fully deprecated.
# pylint: disable=protected-access
from tensorflow.python.ops.rnn_cell import _linear as linear
# pylint: enable=protected-access


class RNNCellTest(tf.test.TestCase):
  """Smoke and shape tests for `linear` and the `tf.nn.rnn_cell` cells.

  Each test builds its graph under a variable scope with a constant
  initializer, feeds concrete values for the placeholder-like zero tensors by
  name, and compares the fetched results against recorded values or shapes.
  """

  # Checks `linear` output values, that a second call in the same scope is
  # rejected, and that variables can be shared through an explicit scope.
  def testLinear(self):
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(1.0)):
        inputs = tf.zeros([1, 2])
        projected = linear([inputs], 2, False)
        sess.run([tf.initialize_all_variables()])
        feed = {inputs.name: np.array([[1., 2.]])}
        (projected_val,) = sess.run([projected], feed)
        self.assertAllClose(projected_val, [[3.0, 3.0]])

        # Calling again in the same scope must fail, so that variables are
        # never shared by accident.
        with self.assertRaises(ValueError):
          linear([inputs], 2, False)

        # Sharing is allowed when done on purpose: create the variables in a
        # fresh scope, then re-enter that scope with reuse=True.
        with tf.variable_scope("l1") as shared_scope:
          first = linear([inputs], 2, False)
        with tf.variable_scope(shared_scope, reuse=True):
          linear([first], 2, False)
        self.assertEqual(len(tf.trainable_variables()), 2)

  # Shape-only check of a single BasicRNNCell step.
  def testBasicRNNCell(self):
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        inputs = tf.zeros([1, 2])
        state = tf.zeros([1, 2])
        cell = tf.nn.rnn_cell.BasicRNNCell(2)
        output, _ = cell(inputs, state)
        sess.run([tf.initialize_all_variables()])
        feed = {inputs.name: np.array([[1., 1.]]),
                state.name: np.array([[0.1, 0.1]])}
        (output_val,) = sess.run([output], feed)
        self.assertEqual(output_val.shape, (1, 2))

  # GRUCell smoke test, once with input_size == num_units and once without.
  def testGRUCell(self):
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        inputs = tf.zeros([1, 2])
        state = tf.zeros([1, 2])
        cell = tf.nn.rnn_cell.GRUCell(2)
        output, _ = cell(inputs, state)
        sess.run([tf.initialize_all_variables()])
        feed = {inputs.name: np.array([[1., 1.]]),
                state.name: np.array([[0.1, 0.1]])}
        (output_val,) = sess.run([output], feed)
        # Recorded value; smoke test only.
        self.assertAllClose(output_val, [[0.175991, 0.175991]])
      with tf.variable_scope("other", initializer=tf.constant_initializer(0.5)):
        # Here input_size (3) differs from num_units (2).
        inputs = tf.zeros([1, 3])
        state = tf.zeros([1, 2])
        cell = tf.nn.rnn_cell.GRUCell(2)
        output, _ = cell(inputs, state)
        sess.run([tf.initialize_all_variables()])
        feed = {inputs.name: np.array([[1., 1., 1.]]),
                state.name: np.array([[0.1, 0.1]])}
        (output_val,) = sess.run([output], feed)
        # Recorded value; smoke test only.
        self.assertAllClose(output_val, [[0.156736, 0.156736]])

  # Two stacked BasicLSTMCells with a single concatenated state tensor,
  # followed by a lone BasicLSTMCell whose input_size differs from num_units.
  def testBasicLSTMCell(self):
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        inputs = tf.zeros([1, 2])
        state = tf.zeros([1, 8])
        stacked = tf.nn.rnn_cell.MultiRNNCell(
            [tf.nn.rnn_cell.BasicLSTMCell(2)] * 2)
        output, new_state = stacked(inputs, state)
        sess.run([tf.initialize_all_variables()])
        feed = {inputs.name: np.array([[1., 1.]]),
                state.name: 0.1 * np.ones([1, 8])}
        results = sess.run([output, new_state], feed)
        self.assertEqual(len(results), 2)
        output_val, state_val = results
        # The expected numbers were recorded rather than derived by hand;
        # this is only a smoke test.
        self.assertAllClose(output_val, [[0.24024698, 0.24024698]])
        expected_state = np.array([[0.68967271, 0.68967271,
                                    0.44848421, 0.44848421,
                                    0.39897051, 0.39897051,
                                    0.24024698, 0.24024698]])
        self.assertAllClose(state_val, expected_state)
      with tf.variable_scope("other", initializer=tf.constant_initializer(0.5)):
        # Here input_size (3) differs from num_units (2).
        inputs = tf.zeros([1, 3])
        state = tf.zeros([1, 4])
        single = tf.nn.rnn_cell.BasicLSTMCell(2)
        output, new_state = single(inputs, state)
        sess.run([tf.initialize_all_variables()])
        feed = {inputs.name: np.array([[1., 1., 1.]]),
                state.name: 0.1 * np.ones([1, 4])}
        results = sess.run([output, new_state], feed)
        self.assertEqual(len(results), 2)

  # Same stack as testBasicLSTMCell, but with one state tensor per layer.
  def testBasicLSTMCellWithStateTuple(self):
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        inputs = tf.zeros([1, 2])
        state0 = tf.zeros([1, 4])
        state1 = tf.zeros([1, 4])
        stacked = tf.nn.rnn_cell.MultiRNNCell(
            [tf.nn.rnn_cell.BasicLSTMCell(2)] * 2, state_is_tuple=True)
        output, (new_state0, new_state1) = stacked(inputs, (state0, state1))
        sess.run([tf.initialize_all_variables()])
        feed = {inputs.name: np.array([[1., 1.]]),
                state0.name: 0.1 * np.ones([1, 4]),
                state1.name: 0.1 * np.ones([1, 4])}
        results = sess.run([output, new_state0, new_state1], feed)
        self.assertEqual(len(results), 3)
        output_val, state0_val, state1_val = results
        # The expected numbers were recorded rather than derived by hand.
        # They should nevertheless agree with the state_is_tuple=False
        # variant of this test.
        self.assertAllClose(output_val, [[0.24024698, 0.24024698]])
        expected_state0 = np.array([[0.68967271, 0.68967271,
                                     0.44848421, 0.44848421]])
        expected_state1 = np.array([[0.39897051, 0.39897051,
                                     0.24024698, 0.24024698]])
        self.assertAllClose(state0_val, expected_state0)
        self.assertAllClose(state1_val, expected_state1)

  # LSTMCell with a projection: checks result shapes and that distinct batch
  # rows yield distinct outputs and states.
  def testLSTMCell(self):
    with self.test_session() as sess:
      num_units = 8
      num_proj = 6
      state_size = num_units + num_proj
      batch_size = 3
      input_size = 2
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        inputs = tf.zeros([batch_size, input_size])
        state = tf.zeros([batch_size, state_size])
        cell = tf.nn.rnn_cell.LSTMCell(
            num_units=num_units, num_proj=num_proj, forget_bias=1.0)
        output, new_state = cell(inputs, state)
        sess.run([tf.initialize_all_variables()])
        feed = {
            inputs.name: np.array([[1., 1.], [2., 2.], [3., 3.]]),
            state.name: 0.1 * np.ones((batch_size, state_size)),
        }
        results = sess.run([output, new_state], feed)
        self.assertEqual(len(results), 2)
        output_val, state_val = results
        # No reference numbers here; this is mostly a smoke test on shapes.
        self.assertEqual(output_val.shape, (batch_size, num_proj))
        self.assertEqual(state_val.shape, (batch_size, state_size))
        # Every row was fed a different input, so each later row must differ
        # from row 0 in both output and state.
        for row in range(1, batch_size):
          output_gap = np.linalg.norm(output_val[0, :] - output_val[row, :])
          self.assertTrue(float(output_gap) > 1e-6)
          state_gap = np.linalg.norm(state_val[0, :] - state_val[row, :])
          self.assertTrue(float(state_gap) > 1e-6)

  # GRUCell(3) wrapped so that its output is projected down to size 2.
  def testOutputProjectionWrapper(self):
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        inputs = tf.zeros([1, 3])
        state = tf.zeros([1, 3])
        gru = tf.nn.rnn_cell.GRUCell(3)
        wrapped = tf.nn.rnn_cell.OutputProjectionWrapper(gru, 2)
        output, new_state = wrapped(inputs, state)
        sess.run([tf.initialize_all_variables()])
        feed = {inputs.name: np.array([[1., 1., 1.]]),
                state.name: np.array([[0.1, 0.1, 0.1]])}
        output_val, state_val = sess.run([output, new_state], feed)
        self.assertEqual(state_val.shape, (1, 3))
        # Recorded value; smoke test only.
        self.assertAllClose(output_val, [[0.231907, 0.231907]])

  # GRUCell(3) fed through an input projection from size 2 to size 3.
  def testInputProjectionWrapper(self):
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        inputs = tf.zeros([1, 2])
        state = tf.zeros([1, 3])
        gru = tf.nn.rnn_cell.GRUCell(3)
        wrapped = tf.nn.rnn_cell.InputProjectionWrapper(gru, num_proj=3)
        output, new_state = wrapped(inputs, state)
        sess.run([tf.initialize_all_variables()])
        feed = {inputs.name: np.array([[1., 1.]]),
                state.name: np.array([[0.1, 0.1, 0.1]])}
        output_val, state_val = sess.run([output, new_state], feed)
        self.assertEqual(state_val.shape, (1, 3))
        # Recorded value; smoke test only.
        self.assertAllClose(output_val, [[0.154605, 0.154605, 0.154605]])

  # GRUCell(3) inside a DropoutWrapper whose two keep arguments both
  # evaluate to 1.
  def testDropoutWrapper(self):
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        inputs = tf.zeros([1, 3])
        state = tf.zeros([1, 3])
        keep_prob = tf.zeros([]) + 1
        gru = tf.nn.rnn_cell.GRUCell(3)
        wrapped = tf.nn.rnn_cell.DropoutWrapper(gru, keep_prob, keep_prob)
        output, new_state = wrapped(inputs, state)
        sess.run([tf.initialize_all_variables()])
        feed = {inputs.name: np.array([[1., 1., 1.]]),
                state.name: np.array([[0.1, 0.1, 0.1]])}
        output_val, state_val = sess.run([output, new_state], feed)
        self.assertEqual(state_val.shape, (1, 3))
        # Recorded value; smoke test only.
        self.assertAllClose(output_val, [[0.154605, 0.154605, 0.154605]])

  # GRUCell(2) driven by an int32 id through an EmbeddingWrapper.
  def testEmbeddingWrapper(self):
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        token_ids = tf.zeros([1, 1], dtype=tf.int32)
        state = tf.zeros([1, 2])
        wrapped = tf.nn.rnn_cell.EmbeddingWrapper(
            tf.nn.rnn_cell.GRUCell(2),
            embedding_classes=3, embedding_size=2)
        output, new_state = wrapped(token_ids, state)
        sess.run([tf.initialize_all_variables()])
        feed = {token_ids.name: np.array([[1]]),
                state.name: np.array([[0.1, 0.1]])}
        output_val, state_val = sess.run([output, new_state], feed)
        self.assertEqual(state_val.shape, (1, 2))
        # Recorded value; smoke test only.
        self.assertAllClose(output_val, [[0.17139, 0.17139]])

  # Two stacked GRUCells with a single concatenated state tensor.
  def testMultiRNNCell(self):
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        inputs = tf.zeros([1, 2])
        state = tf.zeros([1, 4])
        stacked = tf.nn.rnn_cell.MultiRNNCell(
            [tf.nn.rnn_cell.GRUCell(2)] * 2)
        _, new_state = stacked(inputs, state)
        sess.run([tf.initialize_all_variables()])
        feed = {inputs.name: np.array([[1., 1.]]),
                state.name: np.array([[0.1, 0.1, 0.1, 0.1]])}
        state_val = sess.run(new_state, feed)
        # Recorded values; smoke test only.
        self.assertAllClose(state_val, [[0.175991, 0.175991,
                                         0.13248, 0.13248]])

  # Two stacked GRUCells with state_is_tuple=True: a non-tuple state must be
  # rejected, and a tuple state must reproduce testMultiRNNCell's numbers.
  def testMultiRNNCellWithStateTuple(self):
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        inputs = tf.zeros([1, 2])
        flat_state = tf.zeros([1, 4])
        tuple_state = (tf.zeros([1, 2]), tf.zeros([1, 2]))

        # A single tensor is not an acceptable state in tuple mode.
        with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
          rejecting = tf.nn.rnn_cell.MultiRNNCell(
              [tf.nn.rnn_cell.GRUCell(2)] * 2, state_is_tuple=True)
          rejecting(inputs, flat_state)

        stacked = tf.nn.rnn_cell.MultiRNNCell(
            [tf.nn.rnn_cell.GRUCell(2)] * 2, state_is_tuple=True)
        _, new_state = stacked(inputs, tuple_state)

        sess.run([tf.initialize_all_variables()])
        feed = {inputs.name: np.array([[1., 1.]]),
                tuple_state[0].name: np.array([[0.1, 0.1]]),
                tuple_state[1].name: np.array([[0.1, 0.1]])}
        state_vals = sess.run(new_state, feed)

        # Recorded values, not derived by hand; they should match the ones
        # used in testMultiRNNCell.
        self.assertAllClose(state_vals[0], [[0.175991, 0.175991]])
        self.assertAllClose(state_vals[1], [[0.13248, 0.13248]])


class SlimRNNCellTest(tf.test.TestCase):
  """Tests for `tf.nn.rnn_cell.SlimRNNCell` wrapping `basic_rnn_cell`."""

  # Shape-only check of one step through a SlimRNNCell built from
  # `basic_rnn_cell` with num_units bound to 2.
  def testBasicRNNCell(self):
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        inputs = tf.zeros([1, 2])
        state = tf.zeros([1, 2])
        cell_fn = functools.partial(basic_rnn_cell, num_units=2)
        slim_cell = tf.nn.rnn_cell.SlimRNNCell(cell_fn)
        output, _ = slim_cell(inputs, state)
        sess.run([tf.initialize_all_variables()])
        feed = {inputs.name: np.array([[1., 1.]]),
                state.name: np.array([[0.1, 0.1]])}
        (output_val,) = sess.run([output], feed)
        self.assertEqual(output_val.shape, (1, 2))

  # The SlimRNNCell and the built-in BasicRNNCell must agree, both in static
  # shape and in computed values, for the same inputs and initial state.
  def testBasicRNNCellMatch(self):
    batch_size = 32
    input_size = 100
    num_units = 10
    with self.test_session() as sess:
      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        inputs = tf.random_uniform((batch_size, input_size))
        # Passing state=None asks basic_rnn_cell for its zero initial state.
        _, initial_state = basic_rnn_cell(inputs, None, num_units)
        cell_fn = functools.partial(basic_rnn_cell, num_units=num_units)
        slim_cell = tf.nn.rnn_cell.SlimRNNCell(cell_fn)
        slim_outputs, slim_state = slim_cell(inputs, initial_state)
        reference_cell = tf.nn.rnn_cell.BasicRNNCell(num_units)
        ref_outputs, ref_state = reference_cell(inputs, initial_state)
        self.assertEqual(slim_outputs.get_shape(), ref_outputs.get_shape())
        self.assertEqual(slim_state.get_shape(), ref_state.get_shape())
        sess.run([tf.initialize_all_variables()])
        # Fetch everything in a single run so both cells see the same random
        # inputs.
        fetches = [slim_outputs, slim_state, ref_outputs, ref_state]
        slim_out_val, slim_state_val, ref_out_val, ref_state_val = sess.run(
            fetches)
        self.assertAllClose(slim_out_val, ref_out_val)
        self.assertAllClose(slim_state_val, ref_state_val)


def basic_rnn_cell(inputs, state, num_units, scope=None):
  """Function-style basic RNN step used to exercise `SlimRNNCell`.

  Args:
    inputs: Input tensor, or None. Only consulted for batch size and dtype
      when `state` is None.
    state: Previous state tensor, or None to request the initial values.
    num_units: Size of the second dimension of the output and state.
    scope: Optional scope passed to `tf.variable_op_scope`; the default name
      is "BasicRNNCell".

  Returns:
    A pair `(output, state)`. When `state` is None both are zero tensors of
    static shape `[batch_size, num_units]`, where batch size and dtype come
    from `inputs` (or are 0 and `tf.float32` when `inputs` is None).
    Otherwise both elements are the same tensor:
    `tanh(linear([inputs, state], num_units, True))`.
  """
  if state is not None:
    # Regular step: the new state is the output itself.
    with tf.variable_op_scope([inputs, state], scope, "BasicRNNCell"):
      activation = tf.tanh(linear([inputs, state], num_units, True))
    return activation, activation

  # No state supplied: build the zero-valued initial output and state.
  if inputs is None:
    batch_size, dtype = 0, tf.float32
  else:
    batch_size, dtype = inputs.get_shape()[0], inputs.dtype
  static_shape = [batch_size, num_units]
  init_output = tf.zeros(tf.pack([batch_size, num_units]), dtype=dtype)
  init_state = tf.zeros(tf.pack([batch_size, num_units]), dtype=dtype)
  # The shape was passed through tf.pack as a tensor, so set the static
  # shape explicitly on both results.
  init_output.set_shape(static_shape)
  init_state.set_shape(static_shape)
  return init_output, init_state

# When executed as a script (not imported), hand control to tf.test.main().
if __name__ == "__main__":
  tf.test.main()