Support unknown emit shapes in tf.nn.raw_rnn.

PiperOrigin-RevId: 158308002
This commit is contained in:
RJ Ryan 2017-06-07 13:18:59 -07:00 committed by TensorFlower Gardener
parent edb5fed7fc
commit 85e832201e
2 changed files with 8 additions and 4 deletions

View File

@ -2040,6 +2040,9 @@ class RawRNNTest(test.TestCase):
inputs_ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
inputs_ta = inputs_ta.unstack(inputs)
# Verify emit shapes may be unknown by feeding a placeholder that
# determines an emit shape.
unknown_dim = array_ops.placeholder(dtype=dtypes.int32)
cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
@ -2047,12 +2050,12 @@ class RawRNNTest(test.TestCase):
if cell_output is None:
emit_output = (array_ops.zeros(
[2, 3], dtype=dtypes.int32), array_ops.zeros(
[1], dtype=dtypes.int64))
[unknown_dim], dtype=dtypes.int64))
next_state = cell.zero_state(batch_size, dtypes.float32)
else:
emit_output = (array_ops.ones(
[batch_size, 2, 3], dtype=dtypes.int32), array_ops.ones(
[batch_size, 1], dtype=dtypes.int64))
[batch_size, unknown_dim], dtype=dtypes.int64))
next_state = cell_state
elements_finished = array_ops.tile([time_ >= max_time], [batch_size])
finished = math_ops.reduce_all(elements_finished)
@ -2069,7 +2072,7 @@ class RawRNNTest(test.TestCase):
self.assertEqual([dtypes.int32, dtypes.int64],
[ta.dtype for ta in output_ta])
output = [ta.stack() for ta in output_ta]
output_vals = sess.run(output)
output_vals = sess.run(output, feed_dict={unknown_dim: 1})
self.assertAllEqual(
np.ones((max_time, batch_size, 2, 3), np.int32), output_vals[0])
self.assertAllEqual(

View File

@ -984,7 +984,8 @@ def raw_rnn(cell, loop_fn,
if emit_structure is not None:
flat_emit_structure = nest.flatten(emit_structure)
flat_emit_size = [emit.get_shape() for emit in flat_emit_structure]
flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
array_ops.shape(emit) for emit in flat_emit_structure]
flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
else:
emit_structure = cell.output_size