Support unknown emit shapes in tf.nn.raw_rnn.
PiperOrigin-RevId: 158308002
This commit is contained in:
parent
edb5fed7fc
commit
85e832201e
@ -2040,6 +2040,9 @@ class RawRNNTest(test.TestCase):
|
||||
inputs_ta = tensor_array_ops.TensorArray(
|
||||
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
|
||||
inputs_ta = inputs_ta.unstack(inputs)
|
||||
# Verify emit shapes may be unknown by feeding a placeholder that
|
||||
# determines an emit shape.
|
||||
unknown_dim = array_ops.placeholder(dtype=dtypes.int32)
|
||||
|
||||
cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
||||
|
||||
@ -2047,12 +2050,12 @@ class RawRNNTest(test.TestCase):
|
||||
if cell_output is None:
|
||||
emit_output = (array_ops.zeros(
|
||||
[2, 3], dtype=dtypes.int32), array_ops.zeros(
|
||||
[1], dtype=dtypes.int64))
|
||||
[unknown_dim], dtype=dtypes.int64))
|
||||
next_state = cell.zero_state(batch_size, dtypes.float32)
|
||||
else:
|
||||
emit_output = (array_ops.ones(
|
||||
[batch_size, 2, 3], dtype=dtypes.int32), array_ops.ones(
|
||||
[batch_size, 1], dtype=dtypes.int64))
|
||||
[batch_size, unknown_dim], dtype=dtypes.int64))
|
||||
next_state = cell_state
|
||||
elements_finished = array_ops.tile([time_ >= max_time], [batch_size])
|
||||
finished = math_ops.reduce_all(elements_finished)
|
||||
@ -2069,7 +2072,7 @@ class RawRNNTest(test.TestCase):
|
||||
self.assertEqual([dtypes.int32, dtypes.int64],
|
||||
[ta.dtype for ta in output_ta])
|
||||
output = [ta.stack() for ta in output_ta]
|
||||
output_vals = sess.run(output)
|
||||
output_vals = sess.run(output, feed_dict={unknown_dim: 1})
|
||||
self.assertAllEqual(
|
||||
np.ones((max_time, batch_size, 2, 3), np.int32), output_vals[0])
|
||||
self.assertAllEqual(
|
||||
|
@ -984,7 +984,8 @@ def raw_rnn(cell, loop_fn,
|
||||
|
||||
if emit_structure is not None:
|
||||
flat_emit_structure = nest.flatten(emit_structure)
|
||||
flat_emit_size = [emit.get_shape() for emit in flat_emit_structure]
|
||||
flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
|
||||
array_ops.shape(emit) for emit in flat_emit_structure]
|
||||
flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
|
||||
else:
|
||||
emit_structure = cell.output_size
|
||||
|
Loading…
Reference in New Issue
Block a user