diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index 7b46db9ad00..581a09f5b98 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -174,6 +174,45 @@ class RNNTest(tf.test.TestCase):
           1.0 * (2 + 1) * np.ones((input_size)))))
 
 
+class GRUTest(tf.test.TestCase):
+
+  def setUp(self):
+    self._seed = 23489
+    np.random.seed(self._seed)
+
+  def _testDynamic(self, use_gpu):
+    time_steps = 8
+    num_units = 3
+    input_size = 5
+    batch_size = 2
+
+    input_values = np.random.randn(time_steps, batch_size, input_size)
+
+    sequence_length = np.random.randint(0, time_steps, size=batch_size)
+
+    with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
+      concat_inputs = tf.placeholder(
+          tf.float32, shape=(time_steps, batch_size, input_size))
+
+      cell = tf.nn.rnn_cell.GRUCell(num_units=num_units, input_size=input_size)
+
+      with tf.variable_scope("dynamic_scope"):
+        outputs_dynamic, state_dynamic = tf.nn.dynamic_rnn(
+            cell, inputs=concat_inputs, sequence_length=sequence_length,
+            time_major=True, dtype=tf.float32)
+
+      feeds = {concat_inputs: input_values}
+
+      # Initialize
+      tf.initialize_all_variables().run(feed_dict=feeds)
+
+      sess.run([outputs_dynamic, state_dynamic], feed_dict=feeds)
+
+  def testDynamic(self):
+    self._testDynamic(use_gpu=False)
+    self._testDynamic(use_gpu=True)
+
+
 class LSTMTest(tf.test.TestCase):
 
   def setUp(self):
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 1d5d63dcaab..2144e37fdc7 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -465,6 +465,9 @@ def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length,
   input_shape = array_ops.shape(inputs)
   (time_steps, batch_size, unused_depth) = array_ops.unpack(input_shape, 3)
 
+  inputs_got_shape = inputs.get_shape().with_rank(3)
+  (const_time_steps, const_batch_size, const_depth) = inputs_got_shape.as_list()
+
   # Prepare dynamic conditional copying of state & output
   zero_output = array_ops.zeros(
       array_ops.pack([batch_size, cell.output_size]), inputs.dtype)
@@ -484,7 +487,20 @@ def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length,
   input_ta = input_ta.unpack(inputs)
 
   def _time_step(time, state, output_ta_t):
+    """Take a time step of the dynamic RNN.
+
+    Args:
+      time: int32 scalar Tensor.
+      state: Vector.
+      output_ta_t: `TensorArray`, the output with existing flow.
+
+    Returns:
+      The tuple (time + 1, new_state, output_ta_t with updated flow).
+    """
+
     input_t = input_ta.read(time)
+    # Restore some shape information
+    input_t.set_shape([const_batch_size, const_depth])
     (output, new_state) = _rnn_step(
         time, sequence_length, min_sequence_length, max_sequence_length,
@@ -501,5 +517,8 @@ def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length,
       parallel_iterations=parallel_iterations)
 
   final_outputs = output_final_ta.pack()
+  # Restore some shape information
+  final_outputs.set_shape([
+      const_time_steps, const_batch_size, cell.output_size])
 
   return (final_outputs, final_state)
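
Note on the rnn.py hunks: tensors produced by `TensorArray.read()` and `TensorArray.pack()` carry no static shape information, which is why the patch re-annotates them from the statically known input shape. Below is a minimal standalone sketch of that technique — illustrative only, not part of the patch, with made-up tensor names — using only `get_shape()`/`set_shape()` from the same era of the TensorFlow API:

    import tensorflow as tf

    # Reshaping by a dynamic shape tensor degrades the static shape to
    # (?, ?, ?), even though the caller still knows the true dimensions.
    x = tf.placeholder(tf.float32, shape=(8, 2, 5))   # (time, batch, depth)
    y = tf.reshape(x, tf.shape(x))
    print(y.get_shape())    # (?, ?, ?)

    # set_shape() merges caller-known shape information back into the
    # graph, just as input_t.set_shape(...) does inside _time_step().
    y.set_shape(x.get_shape())
    print(y.get_shape())    # (8, 2, 5)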
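
A usage sketch of the net effect (hedged — this restates what the new GRUTest exercises, with hypothetical sizes): once `final_outputs.set_shape(...)` runs, the outputs of `tf.nn.dynamic_rnn` carry a fully defined static shape whenever the inputs do, rather than the unknown shape that `pack()` alone would report:

    import tensorflow as tf

    inputs = tf.placeholder(tf.float32, shape=(8, 2, 5))  # time-major
    cell = tf.nn.rnn_cell.GRUCell(num_units=3, input_size=5)
    outputs, state = tf.nn.dynamic_rnn(
        cell, inputs, time_major=True, dtype=tf.float32)

    # Static shape is now (time_steps, batch_size, cell.output_size).
    print(outputs.get_shape())    # (8, 2, 3)

This matters at graph-construction time: downstream ops (e.g. a projection matmul over the last dimension) can infer their own shapes without waiting for runtime values.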