[tf contrib seq2seq] Expand tile_batch to handle nested structures.
This allows it to properly tile the initial wrapper state when using BeamSearchDecoder with AttentionWrapper. Unit tests updated to show this use. PiperOrigin-RevId: 157903115
This commit is contained in:
parent
1234e2dda6
commit
563f05ff67
@ -226,8 +226,8 @@ class TestBeamStep(test.TestCase):
|
|||||||
class BeamSearchDecoderTest(test.TestCase):
|
class BeamSearchDecoderTest(test.TestCase):
|
||||||
|
|
||||||
def _testDynamicDecodeRNN(self, time_major, has_attention):
|
def _testDynamicDecodeRNN(self, time_major, has_attention):
|
||||||
encoder_sequence_length = [3, 2, 3, 1, 1]
|
encoder_sequence_length = np.array([3, 2, 3, 1, 1])
|
||||||
decoder_sequence_length = [2, 0, 1, 2, 3]
|
decoder_sequence_length = np.array([2, 0, 1, 2, 3])
|
||||||
batch_size = 5
|
batch_size = 5
|
||||||
decoder_max_time = 4
|
decoder_max_time = 4
|
||||||
input_depth = 7
|
input_depth = 7
|
||||||
@ -245,6 +245,7 @@ class BeamSearchDecoderTest(test.TestCase):
|
|||||||
batch_size_tensor = constant_op.constant(batch_size)
|
batch_size_tensor = constant_op.constant(batch_size)
|
||||||
embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
|
embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
|
||||||
cell = rnn_cell.LSTMCell(cell_depth)
|
cell = rnn_cell.LSTMCell(cell_depth)
|
||||||
|
initial_state = cell.zero_state(batch_size, dtypes.float32)
|
||||||
if has_attention:
|
if has_attention:
|
||||||
inputs = array_ops.placeholder_with_default(
|
inputs = array_ops.placeholder_with_default(
|
||||||
np.random.randn(batch_size, decoder_max_time,
|
np.random.randn(batch_size, decoder_max_time,
|
||||||
@ -258,6 +259,8 @@ class BeamSearchDecoderTest(test.TestCase):
|
|||||||
num_units=attention_depth,
|
num_units=attention_depth,
|
||||||
memory=tiled_inputs,
|
memory=tiled_inputs,
|
||||||
memory_sequence_length=tiled_sequence_length)
|
memory_sequence_length=tiled_sequence_length)
|
||||||
|
initial_state = beam_search_decoder.tile_batch(
|
||||||
|
initial_state, multiplier=beam_width)
|
||||||
cell = attention_wrapper.AttentionWrapper(
|
cell = attention_wrapper.AttentionWrapper(
|
||||||
cell=cell,
|
cell=cell,
|
||||||
attention_mechanism=attention_mechanism,
|
attention_mechanism=attention_mechanism,
|
||||||
@ -265,6 +268,9 @@ class BeamSearchDecoderTest(test.TestCase):
|
|||||||
alignment_history=False)
|
alignment_history=False)
|
||||||
cell_state = cell.zero_state(
|
cell_state = cell.zero_state(
|
||||||
dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
|
dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
|
||||||
|
if has_attention:
|
||||||
|
cell_state = cell_state.clone(
|
||||||
|
cell_state=initial_state)
|
||||||
bsd = beam_search_decoder.BeamSearchDecoder(
|
bsd = beam_search_decoder.BeamSearchDecoder(
|
||||||
cell=cell,
|
cell=cell,
|
||||||
embedding=embedding,
|
embedding=embedding,
|
||||||
|
@ -72,27 +72,8 @@ class FinalBeamSearchDecoderOutput(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def tile_batch(t, multiplier, name=None):
|
def _tile_batch(t, multiplier):
|
||||||
"""Tile the batch dimension of tensor t.
|
"""Core single-tensor implementation of tile_batch."""
|
||||||
|
|
||||||
This function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of
|
|
||||||
minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape
|
|
||||||
`[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries
|
|
||||||
`t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated
|
|
||||||
`multiplier` times.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
t: `Tensor` shaped `[batch_size, ...]`.
|
|
||||||
multiplier: Python int.
|
|
||||||
name: Name scope for any created operations.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A `Tensor` shaped `[batch_size * multiplier, ...]`.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if `t` does not have a statically known rank or it's < 1.
|
|
||||||
"""
|
|
||||||
with ops.name_scope(name, "tile_batch", [t, multiplier]):
|
|
||||||
t = ops.convert_to_tensor(t, name="t")
|
t = ops.convert_to_tensor(t, name="t")
|
||||||
shape_t = array_ops.shape(t)
|
shape_t = array_ops.shape(t)
|
||||||
if t.shape.ndims is None or t.shape.ndims < 1:
|
if t.shape.ndims is None or t.shape.ndims < 1:
|
||||||
@ -110,6 +91,34 @@ def tile_batch(t, multiplier, name=None):
|
|||||||
return tiled
|
return tiled
|
||||||
|
|
||||||
|
|
||||||
|
def tile_batch(t, multiplier, name=None):
|
||||||
|
"""Tile the batch dimension of a (possibly nested structure of) tensor(s) t.
|
||||||
|
|
||||||
|
For each tensor t in a (possibly nested structure) of tensors,
|
||||||
|
this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of
|
||||||
|
minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape
|
||||||
|
`[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries
|
||||||
|
`t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated
|
||||||
|
`multiplier` times.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
t: `Tensor` shaped `[batch_size, ...]`.
|
||||||
|
multiplier: Python int.
|
||||||
|
name: Name scope for any created operations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A (possibly nested structure of) `Tensor` shaped
|
||||||
|
`[batch_size * multiplier, ...]`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if tensor(s) `t` do not have a statically known rank or
|
||||||
|
the rank is < 1.
|
||||||
|
"""
|
||||||
|
flat_t = nest.flatten(t)
|
||||||
|
with ops.name_scope(name, "tile_batch", flat_t + [multiplier]):
|
||||||
|
return nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t)
|
||||||
|
|
||||||
|
|
||||||
def _check_maybe(t):
|
def _check_maybe(t):
|
||||||
if isinstance(t, tensor_array_ops.TensorArray):
|
if isinstance(t, tensor_array_ops.TensorArray):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
|
Loading…
Reference in New Issue
Block a user