Bugfix to dynamic_rnn: restore some shape information when possible during loop.
Change: 115717351
This commit is contained in:
parent
1df80e82ac
commit
1827f66ab8
@ -174,6 +174,45 @@ class RNNTest(tf.test.TestCase):
|
|||||||
1.0 * (2 + 1) * np.ones((input_size)))))
|
1.0 * (2 + 1) * np.ones((input_size)))))
|
||||||
|
|
||||||
|
|
||||||
|
class GRUTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self._seed = 23489
|
||||||
|
np.random.seed(self._seed)
|
||||||
|
|
||||||
|
def _testDynamic(self, use_gpu):
|
||||||
|
time_steps = 8
|
||||||
|
num_units = 3
|
||||||
|
input_size = 5
|
||||||
|
batch_size = 2
|
||||||
|
|
||||||
|
input_values = np.random.randn(time_steps, batch_size, input_size)
|
||||||
|
|
||||||
|
sequence_length = np.random.randint(0, time_steps, size=batch_size)
|
||||||
|
|
||||||
|
with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
|
||||||
|
concat_inputs = tf.placeholder(
|
||||||
|
tf.float32, shape=(time_steps, batch_size, input_size))
|
||||||
|
|
||||||
|
cell = tf.nn.rnn_cell.GRUCell(num_units=num_units, input_size=input_size)
|
||||||
|
|
||||||
|
with tf.variable_scope("dynamic_scope"):
|
||||||
|
outputs_dynamic, state_dynamic = tf.nn.dynamic_rnn(
|
||||||
|
cell, inputs=concat_inputs, sequence_length=sequence_length,
|
||||||
|
time_major=True, dtype=tf.float32)
|
||||||
|
|
||||||
|
feeds = {concat_inputs: input_values}
|
||||||
|
|
||||||
|
# Initialize
|
||||||
|
tf.initialize_all_variables().run(feed_dict=feeds)
|
||||||
|
|
||||||
|
sess.run([outputs_dynamic, state_dynamic], feed_dict=feeds)
|
||||||
|
|
||||||
|
def testDynamic(self):
|
||||||
|
self._testDynamic(use_gpu=False)
|
||||||
|
self._testDynamic(use_gpu=True)
|
||||||
|
|
||||||
|
|
||||||
class LSTMTest(tf.test.TestCase):
|
class LSTMTest(tf.test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
@ -465,6 +465,9 @@ def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length,
|
|||||||
input_shape = array_ops.shape(inputs)
|
input_shape = array_ops.shape(inputs)
|
||||||
(time_steps, batch_size, unused_depth) = array_ops.unpack(input_shape, 3)
|
(time_steps, batch_size, unused_depth) = array_ops.unpack(input_shape, 3)
|
||||||
|
|
||||||
|
inputs_got_shape = inputs.get_shape().with_rank(3)
|
||||||
|
(const_time_steps, const_batch_size, const_depth) = inputs_got_shape.as_list()
|
||||||
|
|
||||||
# Prepare dynamic conditional copying of state & output
|
# Prepare dynamic conditional copying of state & output
|
||||||
zero_output = array_ops.zeros(
|
zero_output = array_ops.zeros(
|
||||||
array_ops.pack([batch_size, cell.output_size]), inputs.dtype)
|
array_ops.pack([batch_size, cell.output_size]), inputs.dtype)
|
||||||
@ -484,7 +487,20 @@ def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length,
|
|||||||
input_ta = input_ta.unpack(inputs)
|
input_ta = input_ta.unpack(inputs)
|
||||||
|
|
||||||
def _time_step(time, state, output_ta_t):
|
def _time_step(time, state, output_ta_t):
|
||||||
|
"""Take a time step of the dynamic RNN.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
time: int32 scalar Tensor.
|
||||||
|
state: Vector.
|
||||||
|
output_ta_t: `TensorArray`, the output with existing flow.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The tuple (time + 1, new_state, output_ta_t with updated flow).
|
||||||
|
"""
|
||||||
|
|
||||||
input_t = input_ta.read(time)
|
input_t = input_ta.read(time)
|
||||||
|
# Restore some shape information
|
||||||
|
input_t.set_shape([const_batch_size, const_depth])
|
||||||
|
|
||||||
(output, new_state) = _rnn_step(
|
(output, new_state) = _rnn_step(
|
||||||
time, sequence_length, min_sequence_length, max_sequence_length,
|
time, sequence_length, min_sequence_length, max_sequence_length,
|
||||||
@ -501,5 +517,8 @@ def _dynamic_rnn_loop(cell, inputs, initial_state, sequence_length,
|
|||||||
parallel_iterations=parallel_iterations)
|
parallel_iterations=parallel_iterations)
|
||||||
|
|
||||||
final_outputs = output_final_ta.pack()
|
final_outputs = output_final_ta.pack()
|
||||||
|
# Restore some shape information
|
||||||
|
final_outputs.set_shape([
|
||||||
|
const_time_steps, const_batch_size, cell.output_size])
|
||||||
|
|
||||||
return (final_outputs, final_state)
|
return (final_outputs, final_state)
|
||||||
|
Loading…
Reference in New Issue
Block a user