diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index bf24347c43b..d250af90370 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -2040,6 +2040,9 @@ class RawRNNTest(test.TestCase): inputs_ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, size=array_ops.shape(inputs)[0]) inputs_ta = inputs_ta.unstack(inputs) + # Verify emit shapes may be unknown by feeding a placeholder that + # determines an emit shape. + unknown_dim = array_ops.placeholder(dtype=dtypes.int32) cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True) @@ -2047,12 +2050,12 @@ class RawRNNTest(test.TestCase): if cell_output is None: emit_output = (array_ops.zeros( [2, 3], dtype=dtypes.int32), array_ops.zeros( - [1], dtype=dtypes.int64)) + [unknown_dim], dtype=dtypes.int64)) next_state = cell.zero_state(batch_size, dtypes.float32) else: emit_output = (array_ops.ones( [batch_size, 2, 3], dtype=dtypes.int32), array_ops.ones( - [batch_size, 1], dtype=dtypes.int64)) + [batch_size, unknown_dim], dtype=dtypes.int64)) next_state = cell_state elements_finished = array_ops.tile([time_ >= max_time], [batch_size]) finished = math_ops.reduce_all(elements_finished) @@ -2069,7 +2072,7 @@ class RawRNNTest(test.TestCase): self.assertEqual([dtypes.int32, dtypes.int64], [ta.dtype for ta in output_ta]) output = [ta.stack() for ta in output_ta] - output_vals = sess.run(output) + output_vals = sess.run(output, feed_dict={unknown_dim: 1}) self.assertAllEqual( np.ones((max_time, batch_size, 2, 3), np.int32), output_vals[0]) self.assertAllEqual( diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index ca72734707e..3c3c18b1c9b 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -984,7 +984,8 @@ def raw_rnn(cell, loop_fn, if emit_structure is not None: flat_emit_structure = nest.flatten(emit_structure) - flat_emit_size = [emit.get_shape() for emit in flat_emit_structure] + flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else + array_ops.shape(emit) for emit in flat_emit_structure] flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure] else: emit_structure = cell.output_size