[tf contrib seq2seq] Expand tile_batch to handle nested structures.

This allows it to properly tile the initial wrapper state when using
BeamSearchDecoder with AttentionWrapper.  Unit tests updated to show this use.

PiperOrigin-RevId: 157903115
Eugene Brevdo, 2017-06-02 20:22:41 -07:00 (committed by TensorFlower Gardener)
parent 1234e2dda6
commit 563f05ff67
2 changed files with 38 additions and 23 deletions
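
For context, the updated unit test (diffed below) exercises roughly the following pattern. This is a minimal sketch against the TF 1.x tf.contrib.seq2seq API, not code from the commit: the shapes and hyperparameters are made up, the encoder final state is a zero-state stand-in, and tile_batch is assumed to be exported as tf.contrib.seq2seq.tile_batch (otherwise it can be imported from the beam_search_decoder module, as the test does).

import tensorflow as tf
from tensorflow.contrib import seq2seq

# Illustrative sizes only.
batch_size, max_time, depth, vocab_size, beam_width = 5, 8, 7, 20, 3

memory = tf.placeholder(tf.float32, [batch_size, max_time, depth])
memory_sequence_length = tf.placeholder(tf.int32, [batch_size])
embedding = tf.get_variable("embedding", [vocab_size, 6])
start_tokens = tf.fill([batch_size], 1)
end_token = 2

cell = tf.contrib.rnn.LSTMCell(9)
# Stand-in for a real encoder's final state (anything with the cell's state structure).
encoder_final_state = cell.zero_state(batch_size, tf.float32)

# Tile every beam-search input along the batch dimension, including the nested
# encoder state -- the nested-structure support is what this commit adds.
tiled_memory = seq2seq.tile_batch(memory, multiplier=beam_width)
tiled_sequence_length = seq2seq.tile_batch(
    memory_sequence_length, multiplier=beam_width)
tiled_encoder_final_state = seq2seq.tile_batch(
    encoder_final_state, multiplier=beam_width)

attention_mechanism = seq2seq.BahdanauAttention(
    num_units=6,
    memory=tiled_memory,
    memory_sequence_length=tiled_sequence_length)
cell = seq2seq.AttentionWrapper(cell, attention_mechanism)

# Build the wrapper's zero state at the tiled batch size, then swap in the
# tiled encoder state as its underlying cell state.
decoder_initial_state = cell.zero_state(
    batch_size=batch_size * beam_width, dtype=tf.float32)
decoder_initial_state = decoder_initial_state.clone(
    cell_state=tiled_encoder_final_state)

decoder = seq2seq.BeamSearchDecoder(
    cell=cell,
    embedding=embedding,
    start_tokens=start_tokens,
    end_token=end_token,
    initial_state=decoder_initial_state,
    beam_width=beam_width)
outputs = seq2seq.dynamic_decode(decoder, maximum_iterations=10)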


@@ -226,8 +226,8 @@ class TestBeamStep(test.TestCase):
 class BeamSearchDecoderTest(test.TestCase):
 
   def _testDynamicDecodeRNN(self, time_major, has_attention):
-    encoder_sequence_length = [3, 2, 3, 1, 1]
-    decoder_sequence_length = [2, 0, 1, 2, 3]
+    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
+    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
     batch_size = 5
     decoder_max_time = 4
     input_depth = 7
@@ -245,6 +245,7 @@ class BeamSearchDecoderTest(test.TestCase):
     batch_size_tensor = constant_op.constant(batch_size)
     embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
     cell = rnn_cell.LSTMCell(cell_depth)
+    initial_state = cell.zero_state(batch_size, dtypes.float32)
     if has_attention:
       inputs = array_ops.placeholder_with_default(
           np.random.randn(batch_size, decoder_max_time,
@@ -258,6 +259,8 @@ class BeamSearchDecoderTest(test.TestCase):
           num_units=attention_depth,
           memory=tiled_inputs,
           memory_sequence_length=tiled_sequence_length)
+      initial_state = beam_search_decoder.tile_batch(
+          initial_state, multiplier=beam_width)
       cell = attention_wrapper.AttentionWrapper(
           cell=cell,
           attention_mechanism=attention_mechanism,
@@ -265,6 +268,9 @@ class BeamSearchDecoderTest(test.TestCase):
           alignment_history=False)
     cell_state = cell.zero_state(
         dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
+    if has_attention:
+      cell_state = cell_state.clone(
+          cell_state=initial_state)
     bsd = beam_search_decoder.BeamSearchDecoder(
         cell=cell,
         embedding=embedding,

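The second file's changes below replace the old single-tensor tile_batch with a private _tile_batch helper and re-introduce tile_batch as a wrapper that maps _tile_batch over a (possibly nested) structure via nest.map_structure. As a small illustration of the documented semantics, here is a sketch, again assuming the TF 1.x API and that tile_batch is reachable as tf.contrib.seq2seq.tile_batch; the numbers are made up:

import numpy as np
import tensorflow as tf
from tensorflow.contrib import seq2seq

# A nested state structure with batch_size = 3.
state = tf.contrib.rnn.LSTMStateTuple(
    c=tf.constant(np.arange(6.).reshape(3, 2), tf.float32),
    h=tf.constant(np.arange(6.).reshape(3, 2), tf.float32))

# Nested structure in, nested structure out; each leaf tensor is tiled.
tiled = seq2seq.tile_batch(state, multiplier=2)

with tf.Session() as sess:
  c, h = sess.run(tiled)
# c.shape == (6, 2); rows come back as entries [0, 0, 1, 1, 2, 2], i.e. each
# minibatch entry is repeated `multiplier` times along the batch dimension.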

@@ -72,27 +72,8 @@ class FinalBeamSearchDecoderOutput(
   pass
 
 
-def tile_batch(t, multiplier, name=None):
-  """Tile the batch dimension of tensor t.
-
-  This function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of
-  minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape
-  `[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries
-  `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated
-  `multiplier` times.
-
-  Args:
-    t: `Tensor` shaped `[batch_size, ...]`.
-    multiplier: Python int.
-    name: Name scope for any created operations.
-
-  Returns:
-    A `Tensor` shaped `[batch_size * multiplier, ...]`.
-
-  Raises:
-    ValueError: if `t` does not have a statically known rank or it's < 1.
-  """
-  with ops.name_scope(name, "tile_batch", [t, multiplier]):
+def _tile_batch(t, multiplier):
+  """Core single-tensor implementation of tile_batch."""
   t = ops.convert_to_tensor(t, name="t")
   shape_t = array_ops.shape(t)
   if t.shape.ndims is None or t.shape.ndims < 1:
@@ -110,6 +91,34 @@ def tile_batch(t, multiplier, name=None):
   return tiled
 
 
+def tile_batch(t, multiplier, name=None):
+  """Tile the batch dimension of a (possibly nested structure of) tensor(s) t.
+
+  For each tensor t in a (possibly nested structure) of tensors,
+  this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of
+  minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape
+  `[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries
+  `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated
+  `multiplier` times.
+
+  Args:
+    t: `Tensor` shaped `[batch_size, ...]`.
+    multiplier: Python int.
+    name: Name scope for any created operations.
+
+  Returns:
+    A (possibly nested structure of) `Tensor` shaped
+    `[batch_size * multiplier, ...]`.
+
+  Raises:
+    ValueError: if tensor(s) `t` do not have a statically known rank or
+      the rank is < 1.
+  """
+  flat_t = nest.flatten(t)
+  with ops.name_scope(name, "tile_batch", flat_t + [multiplier]):
+    return nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t)
+
+
 def _check_maybe(t):
   if isinstance(t, tensor_array_ops.TensorArray):
     raise TypeError(