diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index ba4933ddf79..c75593e3568 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -38,6 +38,7 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import state_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib @@ -142,6 +143,47 @@ class TestStateSaver(object): self.saved_state[name] = state return array_ops.identity(state) + @property + def batch_size(self): + return self._batch_size + + @property + def state_size(self): + return self._state_size + + +class TestStateSaverWithCounters(TestStateSaver): + """Class wrapper around TestStateSaver. + + A dummy class used for testing of static_state_saving_rnn. It helps test if + save_state and state functions got called same number of time when we + evaluate output of rnn cell and state or either of them separately. It + inherits from the TestStateSaver and adds the counters for calls of functions. + """ + + def __init__(self, batch_size, state_size): + super(TestStateSaverWithCounters, self).__init__(batch_size, state_size) + self._num_state_calls = variables_lib.Variable(0) + self._num_save_state_calls = variables_lib.Variable(0) + + def state(self, name): + with ops_lib.control_dependencies( + [state_ops.assign_add(self._num_state_calls, 1)]): + return super(TestStateSaverWithCounters, self).state(name) + + def save_state(self, name, state): + with ops_lib.control_dependencies([state_ops.assign_add( + self._num_save_state_calls, 1)]): + return super(TestStateSaverWithCounters, self).save_state(name, state) + + @property + def num_state_calls(self): + return self._num_state_calls + + @property + def num_save_state_calls(self): + return self._num_save_state_calls + class RNNTest(test.TestCase): @@ -1792,13 +1834,40 @@ class StateSaverRNNTest(test.TestCase): self._seed = 23489 np.random.seed(self._seed) - def _testScope(self, factory, prefix="prefix", use_outer_scope=True): + def _factory(self, scope, state_saver): + num_units = state_saver.state_size // 2 + batch_size = state_saver.batch_size + input_size = 5 + max_length = 8 + initializer = init_ops.random_uniform_initializer( + -0.01, 0.01, seed=self._seed) + cell = rnn_cell.LSTMCell( + num_units, + use_peepholes=False, + initializer=initializer, + state_is_tuple=False) + inputs = max_length * [ + array_ops.zeros(dtype=dtypes.float32, shape=(batch_size, input_size)) + ] + out, state = rnn.static_state_saving_rnn( + cell, + inputs, + state_saver=state_saver, + state_name="save_lstm", + scope=scope) + return out, state, state_saver + + def _testScope(self, prefix="prefix", use_outer_scope=True): + num_units = 3 + batch_size = 2 + state_saver = TestStateSaver(batch_size, 2 * num_units) + with self.test_session(use_gpu=True, graph=ops_lib.Graph()): if use_outer_scope: with variable_scope.variable_scope(prefix) as scope: - factory(scope) + self._factory(scope=scope, state_saver=state_saver) else: - factory(prefix) + self._factory(scope=prefix, state_saver=state_saver) variables_lib.global_variables_initializer() # check that all the variables names starts @@ -1813,34 +1882,46 @@ class StateSaverRNNTest(test.TestCase): self.assertEqual(len(scope_vars), len(all_vars)) def testStateSaverRNNScope(self): + self._testScope(use_outer_scope=True) + self._testScope(use_outer_scope=False) + self._testScope(prefix=None, use_outer_scope=False) + + def testStateSaverCallsSaveState(self): + """Test that number of calls to state and save_state is equal. + + Test if the order of actual evaluating or skipping evaluation of out, + state tensors, which are the output tensors from static_state_saving_rnn, + have influence on number of calls to save_state and state methods of + state_saver object (the number of calls should be same.) + """ + num_units = 3 - input_size = 5 batch_size = 2 - max_length = 8 + state_saver = TestStateSaverWithCounters(batch_size, 2 * num_units) + out, state, state_saver = self._factory(scope=None, state_saver=state_saver) - def factory(scope): - initializer = init_ops.random_uniform_initializer( - -0.01, 0.01, seed=self._seed) - state_saver = TestStateSaver(batch_size, 2 * num_units) - cell = rnn_cell.LSTMCell( - num_units, - use_peepholes=False, - initializer=initializer, - state_is_tuple=False) - inputs = max_length * [ - array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) - ] - return rnn.static_state_saving_rnn( - cell, - inputs, - state_saver=state_saver, - state_name="save_lstm", - scope=scope) + with self.test_session() as sess: + sess.run(variables_lib.global_variables_initializer()) + sess.run(variables_lib.local_variables_initializer()) - self._testScope(factory, use_outer_scope=True) - self._testScope(factory, use_outer_scope=False) - self._testScope(factory, prefix=None, use_outer_scope=False) + _, _, num_state_calls, num_save_state_calls = sess.run([ + out, + state, + state_saver.num_state_calls, + state_saver.num_save_state_calls]) + self.assertEqual(num_state_calls, num_save_state_calls) + _, num_state_calls, num_save_state_calls = sess.run([ + out, + state_saver.num_state_calls, + state_saver.num_save_state_calls]) + self.assertEqual(num_state_calls, num_save_state_calls) + + _, num_state_calls, num_save_state_calls = sess.run([ + state, + state_saver.num_state_calls, + state_saver.num_save_state_calls]) + self.assertEqual(num_state_calls, num_save_state_calls) class GRUTest(test.TestCase): diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index e94ad90dfd7..c77a18d8904 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -1401,6 +1401,13 @@ def static_state_saving_rnn(cell, outputs[-1] = nest.pack_sequence_as( structure=last_output, flat_sequence=flat_last_output) + if state_is_tuple: + state = nest.pack_sequence_as( + structure=state, + flat_sequence=[array_ops.identity(s) for s in flat_state]) + else: + state = array_ops.identity(state) + return (outputs, state)