Bugfix to dynamic_rnn: restore some shape information when possible during loop.

Change: 115717351
Eugene Brevdo 2016-02-26 15:44:32 -08:00 committed by TensorFlower Gardener
parent 1df80e82ac
commit 1827f66ab8
2 changed files with 58 additions and 0 deletions
tensorflow/python/kernel_tests
tensorflow/python/ops
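The fix addresses a quirk of graph-mode shape inference: inside the tf.while_loop that drives dynamic_rnn, tensors read back from a TensorArray come out with their static shape erased, so downstream get_shape() calls report unknown dimensions even when the caller fed a fully specified placeholder. The commit records the statically known dimensions before the loop and reattaches them with set_shape() where they are lost. A minimal sketch of the underlying behavior as of this era of TensorFlow (illustrative only; the tensors and sizes are made up):

import tensorflow as tf

ta = tf.TensorArray(dtype=tf.float32, size=8)
ta = ta.unpack(tf.zeros([8, 2, 5]))  # each entry is a (2, 5) time slice
x = ta.read(0)
print(x.get_shape())   # unknown -- the static shape is lost on read
x.set_shape([2, 5])    # reattach what is statically known
print(x.get_shape())   # (2, 5)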

@@ -174,6 +174,45 @@ class RNNTest(tf.test.TestCase):
            1.0 * (2 + 1) * np.ones((input_size)))))


class GRUTest(tf.test.TestCase):

  def setUp(self):
    self._seed = 23489
    np.random.seed(self._seed)

  def _testDynamic(self, use_gpu):
    time_steps = 8
    num_units = 3
    input_size = 5
    batch_size = 2

    input_values = np.random.randn(time_steps, batch_size, input_size)
    sequence_length = np.random.randint(0, time_steps, size=batch_size)

    with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
      concat_inputs = tf.placeholder(
          tf.float32, shape=(time_steps, batch_size, input_size))
      cell = tf.nn.rnn_cell.GRUCell(num_units=num_units, input_size=input_size)

      with tf.variable_scope("dynamic_scope"):
        outputs_dynamic, state_dynamic = tf.nn.dynamic_rnn(
            cell, inputs=concat_inputs, sequence_length=sequence_length,
            time_major=True, dtype=tf.float32)

      feeds = {concat_inputs: input_values}

      # Initialize
      tf.initialize_all_variables().run(feed_dict=feeds)

      sess.run([outputs_dynamic, state_dynamic], feed_dict=feeds)

  def testDynamic(self):
    self._testDynamic(use_gpu=False)
    self._testDynamic(use_gpu=True)


class LSTMTest(tf.test.TestCase):

  def setUp(self):
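With the shape restoration in place, tf.nn.dynamic_rnn outputs carry static shape information as soon as the graph is built. An assertion like the following (not part of this commit; shown only to illustrate what the fix enables inside _testDynamic) would now hold, since with time_major=True the outputs have shape (time_steps, batch_size, num_units):

      # Hypothetical extra check: the static shape is known at graph
      # construction time, before anything is run.
      self.assertEqual(outputs_dynamic.get_shape().as_list(),
                       [time_steps, batch_size, num_units])  # [8, 2, 3]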

@@ -465,6 +465,9 @@ def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length,
  input_shape = array_ops.shape(inputs)
  (time_steps, batch_size, unused_depth) = array_ops.unpack(input_shape, 3)

  inputs_got_shape = inputs.get_shape().with_rank(3)
  (const_time_steps, const_batch_size, const_depth) = inputs_got_shape.as_list()

  # Prepare dynamic conditional copying of state & output
  zero_output = array_ops.zeros(
      array_ops.pack([batch_size, cell.output_size]), inputs.dtype)
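The two added lines capture whatever is statically known about inputs: with_rank(3) raises ValueError if the known rank is incompatible with 3, and as_list() returns one Python int or None per dimension, so const_time_steps, const_batch_size, and const_depth may each be None. That is safe for the set_shape() calls below, because set_shape() merges shapes: a None entry adds no constraint, while an incompatible known entry fails loudly. A small sketch of the semantics relied on (illustrative, not from the commit):

s = tf.TensorShape([None, 2, 5]).with_rank(3)
print(s.as_list())        # [None, 2, 5] -- None marks an unknown dimension

x = tf.placeholder(tf.float32)   # shape entirely unknown
x.set_shape([None, 2, 5])        # merge: only the known dimensions constrain
x.set_shape([7, 2, 5])           # fine: further refines the first dimension
# x.set_shape([7, 3, 5])         # would raise ValueError: incompatible shapes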
@@ -484,7 +487,20 @@ def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length,
  input_ta = input_ta.unpack(inputs)

  def _time_step(time, state, output_ta_t):
    """Take a time step of the dynamic RNN.

    Args:
      time: int32 scalar Tensor.
      state: Vector.
      output_ta_t: `TensorArray`, the output with existing flow.

    Returns:
      The tuple (time + 1, new_state, output_ta_t with updated flow).
    """
    input_t = input_ta.read(time)
    # Restore some shape information
    input_t.set_shape([const_batch_size, const_depth])

    (output, new_state) = _rnn_step(
        time, sequence_length, min_sequence_length, max_sequence_length,
@@ -501,5 +517,8 @@ def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length,
      parallel_iterations=parallel_iterations)

  final_outputs = output_final_ta.pack()

  # Restore some shape information
  final_outputs.set_shape([
      const_time_steps, const_batch_size, cell.output_size])

  return (final_outputs, final_state)
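Like TensorArray.read(), output_final_ta.pack() produces a tensor with no static shape, so the final set_shape() reattaches the dimensions recorded before the loop, with cell.output_size filling in the depth. The net effect, assuming the fully specified placeholder from the GRU test above (names taken from that test; illustrative only):

outputs, _ = tf.nn.dynamic_rnn(cell, inputs=concat_inputs,
                               sequence_length=sequence_length,
                               time_major=True, dtype=tf.float32)
# Before this commit: outputs.get_shape() was completely unknown.
print(outputs.get_shape())  # now (8, 2, 3): (time_steps, batch_size, num_units)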