[tf contrib seq2seq] Expand tile_batch to handle nested structures.

This allows it to properly tile the initial wrapper state when using
BeamSearchDecoder with AttentionWrapper.  Unit tests updated to show this use.

PiperOrigin-RevId: 157903115
Eugene Brevdo 2017-06-02 20:22:41 -07:00 committed by TensorFlower Gardener
parent 1234e2dda6
commit 563f05ff67
2 changed files with 38 additions and 23 deletions
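For context, a minimal sketch of what the nested-structure support enables: an `LSTMStateTuple` can now be tiled with a single `tile_batch` call. This assumes a TF 1.x build with `tf.contrib` available; the sizes are illustrative only.

import tensorflow as tf  # assumed TF 1.x graph-mode API

cell = tf.nn.rnn_cell.LSTMCell(4)
state = cell.zero_state(batch_size=2, dtype=tf.float32)  # LSTMStateTuple(c, h), each [2, 4]

# tile_batch now maps over the nested structure and preserves the tuple type.
tiled = tf.contrib.seq2seq.tile_batch(state, multiplier=3)
print(tiled.c.shape, tiled.h.shape)  # (6, 4) (6, 4): batch 2 tiled by beam width 3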


@@ -226,8 +226,8 @@ class TestBeamStep(test.TestCase):
 class BeamSearchDecoderTest(test.TestCase):
 
   def _testDynamicDecodeRNN(self, time_major, has_attention):
-    encoder_sequence_length = [3, 2, 3, 1, 1]
-    decoder_sequence_length = [2, 0, 1, 2, 3]
+    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
+    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
     batch_size = 5
     decoder_max_time = 4
     input_depth = 7
@@ -245,6 +245,7 @@ class BeamSearchDecoderTest(test.TestCase):
     batch_size_tensor = constant_op.constant(batch_size)
     embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
     cell = rnn_cell.LSTMCell(cell_depth)
+    initial_state = cell.zero_state(batch_size, dtypes.float32)
     if has_attention:
       inputs = array_ops.placeholder_with_default(
           np.random.randn(batch_size, decoder_max_time,
@@ -258,6 +259,8 @@ class BeamSearchDecoderTest(test.TestCase):
           num_units=attention_depth,
           memory=tiled_inputs,
           memory_sequence_length=tiled_sequence_length)
+      initial_state = beam_search_decoder.tile_batch(
+          initial_state, multiplier=beam_width)
       cell = attention_wrapper.AttentionWrapper(
           cell=cell,
           attention_mechanism=attention_mechanism,
@@ -265,6 +268,9 @@ class BeamSearchDecoderTest(test.TestCase):
           alignment_history=False)
     cell_state = cell.zero_state(
         dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
+    if has_attention:
+      cell_state = cell_state.clone(
+          cell_state=initial_state)
     bsd = beam_search_decoder.BeamSearchDecoder(
         cell=cell,
         embedding=embedding,
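Condensed, the wiring this updated test exercises looks roughly like the sketch below, written against the public `tf.contrib.seq2seq` names rather than the test's internal imports; `memory`, `memory_length`, `start_tokens`, `end_token`, and the sizes are illustrative placeholders.

import numpy as np
import tensorflow as tf  # assumed TF 1.x with tf.contrib available

batch_size, beam_width, num_units = 5, 3, 9
memory = tf.constant(np.random.randn(batch_size, 8, 7), dtype=tf.float32)
memory_length = tf.constant([8, 6, 8, 3, 5], dtype=tf.int32)
embedding = tf.constant(np.random.randn(20, 6), dtype=tf.float32)
start_tokens = tf.fill([batch_size], 1)
end_token = 2

cell = tf.nn.rnn_cell.LSTMCell(num_units)
initial_state = cell.zero_state(batch_size, tf.float32)  # nested LSTMStateTuple

# Everything fed to beam search must be tiled beam_width times, including the
# (nested) initial cell state, which is what this change makes a one-liner.
tiled_memory = tf.contrib.seq2seq.tile_batch(memory, multiplier=beam_width)
tiled_memory_length = tf.contrib.seq2seq.tile_batch(memory_length, multiplier=beam_width)
tiled_initial_state = tf.contrib.seq2seq.tile_batch(initial_state, multiplier=beam_width)

attention = tf.contrib.seq2seq.BahdanauAttention(
    num_units, memory=tiled_memory, memory_sequence_length=tiled_memory_length)
cell = tf.contrib.seq2seq.AttentionWrapper(cell, attention)
decoder_initial_state = cell.zero_state(
    batch_size * beam_width, tf.float32).clone(cell_state=tiled_initial_state)

decoder = tf.contrib.seq2seq.BeamSearchDecoder(
    cell, embedding, start_tokens, end_token, decoder_initial_state, beam_width)
outputs = tf.contrib.seq2seq.dynamic_decode(decoder, maximum_iterations=10)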


@@ -72,10 +72,30 @@ class FinalBeamSearchDecoderOutput(
   pass
 
 
-def tile_batch(t, multiplier, name=None):
-  """Tile the batch dimension of tensor t.
+def _tile_batch(t, multiplier):
+  """Core single-tensor implementation of tile_batch."""
+  t = ops.convert_to_tensor(t, name="t")
+  shape_t = array_ops.shape(t)
+  if t.shape.ndims is None or t.shape.ndims < 1:
+    raise ValueError("t must have statically known rank")
+  tiling = [1] * (t.shape.ndims + 1)
+  tiling[1] = multiplier
+  tiled_static_batch_size = (
+      t.shape[0].value * multiplier if t.shape[0].value is not None else None)
+  tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling)
+  tiled = array_ops.reshape(
+      tiled, array_ops.concat(([shape_t[0] * multiplier], shape_t[1:]), 0))
+  tiled.set_shape(
+      tensor_shape.TensorShape(
+          [tiled_static_batch_size]).concatenate(t.shape[1:]))
+  return tiled
+
+
+def tile_batch(t, multiplier, name=None):
+  """Tile the batch dimension of a (possibly nested structure of) tensor(s) t.
 
-  This function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of
+  For each tensor t in a (possibly nested structure) of tensors,
+  this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of
   minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape
   `[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries
   `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated
@@ -87,27 +107,16 @@ def tile_batch(t, multiplier, name=None):
     name: Name scope for any created operations.
 
   Returns:
-    A `Tensor` shaped `[batch_size * multiplier, ...]`.
+    A (possibly nested structure of) `Tensor` shaped
+    `[batch_size * multiplier, ...]`.
 
   Raises:
-    ValueError: if `t` does not have a statically known rank or it's < 1.
+    ValueError: if tensor(s) `t` do not have a statically known rank or
+      the rank is < 1.
   """
-  with ops.name_scope(name, "tile_batch", [t, multiplier]):
-    t = ops.convert_to_tensor(t, name="t")
-    shape_t = array_ops.shape(t)
-    if t.shape.ndims is None or t.shape.ndims < 1:
-      raise ValueError("t must have statically known rank")
-    tiling = [1] * (t.shape.ndims + 1)
-    tiling[1] = multiplier
-    tiled_static_batch_size = (
-        t.shape[0].value * multiplier if t.shape[0].value is not None else None)
-    tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling)
-    tiled = array_ops.reshape(
-        tiled, array_ops.concat(([shape_t[0] * multiplier], shape_t[1:]), 0))
-    tiled.set_shape(
-        tensor_shape.TensorShape(
-            [tiled_static_batch_size]).concatenate(t.shape[1:]))
-    return tiled
+  flat_t = nest.flatten(t)
+  with ops.name_scope(name, "tile_batch", flat_t + [multiplier]):
+    return nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t)
 
 
 def _check_maybe(t):
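For intuition, the repeat pattern `_tile_batch` produces can be sketched in plain NumPy. This is only a rough analogue of the TensorFlow code above and skips the static-shape bookkeeping done via `set_shape`.

import numpy as np

def np_tile_batch(t, multiplier):
  # Insert an axis after the batch dimension, repeat along it, then fold it
  # back into the batch, so entry i appears `multiplier` times in a row.
  t = np.asarray(t)
  tiled = np.expand_dims(t, 1)                               # [batch, 1, s0, ...]
  tiled = np.tile(tiled, [1, multiplier] + [1] * (t.ndim - 1))
  return tiled.reshape((t.shape[0] * multiplier,) + t.shape[1:])

x = np.array([[1, 2], [3, 4]])  # batch_size = 2
print(np_tile_batch(x, 3))
# [[1 2]
#  [1 2]
#  [1 2]
#  [3 4]
#  [3 4]
#  [3 4]]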