Support unknown emit shapes in tf.nn.raw_rnn.
PiperOrigin-RevId: 158308002
This commit is contained in:
parent
edb5fed7fc
commit
85e832201e
@ -2040,6 +2040,9 @@ class RawRNNTest(test.TestCase):
|
|||||||
inputs_ta = tensor_array_ops.TensorArray(
|
inputs_ta = tensor_array_ops.TensorArray(
|
||||||
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
|
dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
|
||||||
inputs_ta = inputs_ta.unstack(inputs)
|
inputs_ta = inputs_ta.unstack(inputs)
|
||||||
|
# Verify emit shapes may be unknown by feeding a placeholder that
|
||||||
|
# determines an emit shape.
|
||||||
|
unknown_dim = array_ops.placeholder(dtype=dtypes.int32)
|
||||||
|
|
||||||
cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
|
||||||
|
|
||||||
@ -2047,12 +2050,12 @@ class RawRNNTest(test.TestCase):
|
|||||||
if cell_output is None:
|
if cell_output is None:
|
||||||
emit_output = (array_ops.zeros(
|
emit_output = (array_ops.zeros(
|
||||||
[2, 3], dtype=dtypes.int32), array_ops.zeros(
|
[2, 3], dtype=dtypes.int32), array_ops.zeros(
|
||||||
[1], dtype=dtypes.int64))
|
[unknown_dim], dtype=dtypes.int64))
|
||||||
next_state = cell.zero_state(batch_size, dtypes.float32)
|
next_state = cell.zero_state(batch_size, dtypes.float32)
|
||||||
else:
|
else:
|
||||||
emit_output = (array_ops.ones(
|
emit_output = (array_ops.ones(
|
||||||
[batch_size, 2, 3], dtype=dtypes.int32), array_ops.ones(
|
[batch_size, 2, 3], dtype=dtypes.int32), array_ops.ones(
|
||||||
[batch_size, 1], dtype=dtypes.int64))
|
[batch_size, unknown_dim], dtype=dtypes.int64))
|
||||||
next_state = cell_state
|
next_state = cell_state
|
||||||
elements_finished = array_ops.tile([time_ >= max_time], [batch_size])
|
elements_finished = array_ops.tile([time_ >= max_time], [batch_size])
|
||||||
finished = math_ops.reduce_all(elements_finished)
|
finished = math_ops.reduce_all(elements_finished)
|
||||||
@ -2069,7 +2072,7 @@ class RawRNNTest(test.TestCase):
|
|||||||
self.assertEqual([dtypes.int32, dtypes.int64],
|
self.assertEqual([dtypes.int32, dtypes.int64],
|
||||||
[ta.dtype for ta in output_ta])
|
[ta.dtype for ta in output_ta])
|
||||||
output = [ta.stack() for ta in output_ta]
|
output = [ta.stack() for ta in output_ta]
|
||||||
output_vals = sess.run(output)
|
output_vals = sess.run(output, feed_dict={unknown_dim: 1})
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
np.ones((max_time, batch_size, 2, 3), np.int32), output_vals[0])
|
np.ones((max_time, batch_size, 2, 3), np.int32), output_vals[0])
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
|
@ -984,7 +984,8 @@ def raw_rnn(cell, loop_fn,
|
|||||||
|
|
||||||
if emit_structure is not None:
|
if emit_structure is not None:
|
||||||
flat_emit_structure = nest.flatten(emit_structure)
|
flat_emit_structure = nest.flatten(emit_structure)
|
||||||
flat_emit_size = [emit.get_shape() for emit in flat_emit_structure]
|
flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
|
||||||
|
array_ops.shape(emit) for emit in flat_emit_structure]
|
||||||
flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
|
flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
|
||||||
else:
|
else:
|
||||||
emit_structure = cell.output_size
|
emit_structure = cell.output_size
|
||||||
|
Loading…
Reference in New Issue
Block a user