[tf contrib seq2seq] Expand tile_batch to handle nested structures.

This allows it to properly tile the initial wrapper state when using
BeamSearchDecoder with AttentionWrapper. Unit tests updated to show this use.

PiperOrigin-RevId: 157903115
parent 1234e2dda6
commit 563f05ff67
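For orientation, the end-to-end pattern these changes enable looks roughly like the sketch below. It mirrors the updated unit test but is not copied from the commit: shapes, vocabulary size, and token ids are made up, and it assumes the TF 1.x contrib API (tf.contrib.seq2seq.tile_batch, AttentionWrapper, BeamSearchDecoder).

    import tensorflow as tf

    batch_size, beam_width, depth = 5, 3, 9
    # Stand-ins for real encoder results (assumed shapes, not from the commit).
    encoder_outputs = tf.zeros([batch_size, 4, depth])      # [batch, time, depth]
    encoder_sequence_length = tf.constant([3, 2, 3, 1, 1])
    encoder_state = tf.contrib.rnn.LSTMStateTuple(           # a nested structure
        c=tf.zeros([batch_size, depth]), h=tf.zeros([batch_size, depth]))
    embedding = tf.zeros([20, 7])                            # [vocab, emb_dim]

    # tile_batch now accepts nested structures, so the whole encoder state can
    # be tiled for beam search in one call, alongside the memory and lengths.
    tiled_memory = tf.contrib.seq2seq.tile_batch(
        encoder_outputs, multiplier=beam_width)
    tiled_lengths = tf.contrib.seq2seq.tile_batch(
        encoder_sequence_length, multiplier=beam_width)
    tiled_encoder_state = tf.contrib.seq2seq.tile_batch(
        encoder_state, multiplier=beam_width)

    attention = tf.contrib.seq2seq.BahdanauAttention(
        num_units=depth, memory=tiled_memory,
        memory_sequence_length=tiled_lengths)
    cell = tf.contrib.seq2seq.AttentionWrapper(
        tf.contrib.rnn.LSTMCell(depth), attention, attention_layer_size=depth)

    # Build the AttentionWrapper zero state at the tiled batch size, then
    # install the tiled encoder state as its cell_state.
    initial_state = cell.zero_state(batch_size * beam_width, tf.float32)
    initial_state = initial_state.clone(cell_state=tiled_encoder_state)

    decoder = tf.contrib.seq2seq.BeamSearchDecoder(
        cell=cell, embedding=embedding,
        start_tokens=tf.fill([batch_size], 0), end_token=1,
        initial_state=initial_state, beam_width=beam_width)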
@@ -226,8 +226,8 @@ class TestBeamStep(test.TestCase):
 class BeamSearchDecoderTest(test.TestCase):
 
   def _testDynamicDecodeRNN(self, time_major, has_attention):
-    encoder_sequence_length = [3, 2, 3, 1, 1]
-    decoder_sequence_length = [2, 0, 1, 2, 3]
+    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
+    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
     batch_size = 5
     decoder_max_time = 4
     input_depth = 7
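A note on the list-to-np.array switch above (an inference from the new implementation below, not stated in the commit): the new tile_batch flattens its argument with TensorFlow's nest utility, and nest treats a plain Python list as a nested structure of scalar leaves rather than as a single tensor-like value; each scalar leaf would then fail _tile_batch's rank >= 1 check. A minimal sketch of the difference:

    import numpy as np
    from tensorflow.python.util import nest  # same module the new tile_batch uses

    nest.flatten([3, 2, 3, 1, 1])            # -> [3, 2, 3, 1, 1]   five scalar leaves
    nest.flatten(np.array([3, 2, 3, 1, 1]))  # -> [array([3, 2, 3, 1, 1])]   one leaf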
@@ -245,6 +245,7 @@ class BeamSearchDecoderTest(test.TestCase):
       batch_size_tensor = constant_op.constant(batch_size)
       embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
       cell = rnn_cell.LSTMCell(cell_depth)
+      initial_state = cell.zero_state(batch_size, dtypes.float32)
       if has_attention:
         inputs = array_ops.placeholder_with_default(
             np.random.randn(batch_size, decoder_max_time,
@@ -258,6 +259,8 @@ class BeamSearchDecoderTest(test.TestCase):
             num_units=attention_depth,
             memory=tiled_inputs,
             memory_sequence_length=tiled_sequence_length)
+        initial_state = beam_search_decoder.tile_batch(
+            initial_state, multiplier=beam_width)
         cell = attention_wrapper.AttentionWrapper(
             cell=cell,
             attention_mechanism=attention_mechanism,
@@ -265,6 +268,9 @@ class BeamSearchDecoderTest(test.TestCase):
           alignment_history=False)
       cell_state = cell.zero_state(
           dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
+      if has_attention:
+        cell_state = cell_state.clone(
+            cell_state=initial_state)
       bsd = beam_search_decoder.BeamSearchDecoder(
           cell=cell,
           embedding=embedding,
@@ -72,10 +72,30 @@ class FinalBeamSearchDecoderOutput(
   pass
 
 
-def tile_batch(t, multiplier, name=None):
-  """Tile the batch dimension of tensor t.
+def _tile_batch(t, multiplier):
+  """Core single-tensor implementation of tile_batch."""
+  t = ops.convert_to_tensor(t, name="t")
+  shape_t = array_ops.shape(t)
+  if t.shape.ndims is None or t.shape.ndims < 1:
+    raise ValueError("t must have statically known rank")
+  tiling = [1] * (t.shape.ndims + 1)
+  tiling[1] = multiplier
+  tiled_static_batch_size = (
+      t.shape[0].value * multiplier if t.shape[0].value is not None else None)
+  tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling)
+  tiled = array_ops.reshape(
+      tiled, array_ops.concat(([shape_t[0] * multiplier], shape_t[1:]), 0))
+  tiled.set_shape(
+      tensor_shape.TensorShape(
+          [tiled_static_batch_size]).concatenate(t.shape[1:]))
+  return tiled
+
+
+def tile_batch(t, multiplier, name=None):
+  """Tile the batch dimension of a (possibly nested structure of) tensor(s) t.
 
-  This function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of
+  For each tensor t in a (possibly nested structure) of tensors,
+  this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of
   minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape
   `[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries
   `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated
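The ordering described in the docstring (t[0], t[0], ..., t[1], t[1], ...) comes from the expand/tile/reshape sequence in _tile_batch. A NumPy sketch of the same trick, with made-up values:

    import numpy as np

    t = np.array([[1, 1], [2, 2]])             # [batch_size=2, s0=2]
    tiled = np.tile(t[:, None, :], [1, 3, 1])  # insert a beam axis, repeat 3 times
    tiled = tiled.reshape([2 * 3, 2])          # merge batch and beam axes
    # Rows are t[0], t[0], t[0], t[1], t[1], t[1]: each minibatch entry is
    # repeated in place, matching the docstring above.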
@@ -87,27 +107,16 @@ def tile_batch(t, multiplier, name=None):
     name: Name scope for any created operations.
 
   Returns:
-    A `Tensor` shaped `[batch_size * multiplier, ...]`.
+    A (possibly nested structure of) `Tensor` shaped
+    `[batch_size * multiplier, ...]`.
 
   Raises:
-    ValueError: if `t` does not have a statically known rank or it's < 1.
+    ValueError: if tensor(s) `t` do not have a statically known rank or
+      the rank is < 1.
   """
-  with ops.name_scope(name, "tile_batch", [t, multiplier]):
-    t = ops.convert_to_tensor(t, name="t")
-    shape_t = array_ops.shape(t)
-    if t.shape.ndims is None or t.shape.ndims < 1:
-      raise ValueError("t must have statically known rank")
-    tiling = [1] * (t.shape.ndims + 1)
-    tiling[1] = multiplier
-    tiled_static_batch_size = (
-        t.shape[0].value * multiplier if t.shape[0].value is not None else None)
-    tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling)
-    tiled = array_ops.reshape(
-        tiled, array_ops.concat(([shape_t[0] * multiplier], shape_t[1:]), 0))
-    tiled.set_shape(
-        tensor_shape.TensorShape(
-            [tiled_static_batch_size]).concatenate(t.shape[1:]))
-    return tiled
+  flat_t = nest.flatten(t)
+  with ops.name_scope(name, "tile_batch", flat_t + [multiplier]):
+    return nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t)
 
 
 def _check_maybe(t):
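As a usage note (a sketch, not part of the commit; it assumes tile_batch is exported as tf.contrib.seq2seq.tile_batch in TF 1.x): the nest-aware version preserves the structure of its argument, so a whole LSTM state can be tiled in one call.

    import tensorflow as tf

    state = tf.contrib.rnn.LSTMStateTuple(
        c=tf.zeros([5, 9]),   # [batch_size, cell_depth]
        h=tf.zeros([5, 9]))
    tiled_state = tf.contrib.seq2seq.tile_batch(state, multiplier=3)
    # tile_batch maps _tile_batch over every leaf and keeps the structure:
    # tiled_state is again an LSTMStateTuple, with c and h of shape [15, 9].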