Bugfix to dynamic_rnn: restore some shape information when possible during loop.
Change: 115717351
This commit is contained in:
parent
1df80e82ac
commit
1827f66ab8
tensorflow/python
@ -174,6 +174,45 @@ class RNNTest(tf.test.TestCase):
|
||||
1.0 * (2 + 1) * np.ones((input_size)))))
|
||||
|
||||
|
||||
class GRUTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._seed = 23489
|
||||
np.random.seed(self._seed)
|
||||
|
||||
def _testDynamic(self, use_gpu):
|
||||
time_steps = 8
|
||||
num_units = 3
|
||||
input_size = 5
|
||||
batch_size = 2
|
||||
|
||||
input_values = np.random.randn(time_steps, batch_size, input_size)
|
||||
|
||||
sequence_length = np.random.randint(0, time_steps, size=batch_size)
|
||||
|
||||
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
|
||||
concat_inputs = tf.placeholder(
|
||||
tf.float32, shape=(time_steps, batch_size, input_size))
|
||||
|
||||
cell = tf.nn.rnn_cell.GRUCell(num_units=num_units, input_size=input_size)
|
||||
|
||||
with tf.variable_scope("dynamic_scope"):
|
||||
outputs_dynamic, state_dynamic = tf.nn.dynamic_rnn(
|
||||
cell, inputs=concat_inputs, sequence_length=sequence_length,
|
||||
time_major=True, dtype=tf.float32)
|
||||
|
||||
feeds = {concat_inputs: input_values}
|
||||
|
||||
# Initialize
|
||||
tf.initialize_all_variables().run(feed_dict=feeds)
|
||||
|
||||
sess.run([outputs_dynamic, state_dynamic], feed_dict=feeds)
|
||||
|
||||
def testDynamic(self):
|
||||
self._testDynamic(use_gpu=False)
|
||||
self._testDynamic(use_gpu=True)
|
||||
|
||||
|
||||
class LSTMTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -465,6 +465,9 @@ def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length,
|
||||
input_shape = array_ops.shape(inputs)
|
||||
(time_steps, batch_size, unused_depth) = array_ops.unpack(input_shape, 3)
|
||||
|
||||
inputs_got_shape = inputs.get_shape().with_rank(3)
|
||||
(const_time_steps, const_batch_size, const_depth) = inputs_got_shape.as_list()
|
||||
|
||||
# Prepare dynamic conditional copying of state & output
|
||||
zero_output = array_ops.zeros(
|
||||
array_ops.pack([batch_size, cell.output_size]), inputs.dtype)
|
||||
@ -484,7 +487,20 @@ def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length,
|
||||
input_ta = input_ta.unpack(inputs)
|
||||
|
||||
def _time_step(time, state, output_ta_t):
|
||||
"""Take a time step of the dynamic RNN.
|
||||
|
||||
Args:
|
||||
time: int32 scalar Tensor.
|
||||
state: Vector.
|
||||
output_ta_t: `TensorArray`, the output with existing flow.
|
||||
|
||||
Returns:
|
||||
The tuple (time + 1, new_state, output_ta_t with updated flow).
|
||||
"""
|
||||
|
||||
input_t = input_ta.read(time)
|
||||
# Restore some shape information
|
||||
input_t.set_shape([const_batch_size, const_depth])
|
||||
|
||||
(output, new_state) = _rnn_step(
|
||||
time, sequence_length, min_sequence_length, max_sequence_length,
|
||||
@ -501,5 +517,8 @@ def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length,
|
||||
parallel_iterations=parallel_iterations)
|
||||
|
||||
final_outputs = output_final_ta.pack()
|
||||
# Restore some shape information
|
||||
final_outputs.set_shape([
|
||||
const_time_steps, const_batch_size, cell.output_size])
|
||||
|
||||
return (final_outputs, final_state)
|
||||
|
Loading…
Reference in New Issue
Block a user