From 563f05ff67b0c2c3b52da71337c4de1c43d09f1a Mon Sep 17 00:00:00 2001
From: Eugene Brevdo
Date: Fri, 2 Jun 2017 20:22:41 -0700
Subject: [PATCH] [tf contrib seq2seq] Expand tile_batch to handle nested structures.

This allows it to properly tile the initial wrapper state when using
BeamSearchDecoder with AttentionWrapper.

Unit tests updated to show this use.

PiperOrigin-RevId: 157903115
---
 .../kernel_tests/beam_search_decoder_test.py  | 10 +++-
 .../seq2seq/python/ops/beam_search_decoder.py | 51 +++++++++++--------
 2 files changed, 38 insertions(+), 23 deletions(-)

diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
index aeafe7c3e59..3d0627467aa 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -226,8 +226,8 @@ class TestBeamStep(test.TestCase):
 class BeamSearchDecoderTest(test.TestCase):
 
   def _testDynamicDecodeRNN(self, time_major, has_attention):
-    encoder_sequence_length = [3, 2, 3, 1, 1]
-    decoder_sequence_length = [2, 0, 1, 2, 3]
+    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
+    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
     batch_size = 5
     decoder_max_time = 4
     input_depth = 7
@@ -245,6 +245,7 @@ class BeamSearchDecoderTest(test.TestCase):
     batch_size_tensor = constant_op.constant(batch_size)
     embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
     cell = rnn_cell.LSTMCell(cell_depth)
+    initial_state = cell.zero_state(batch_size, dtypes.float32)
     if has_attention:
       inputs = array_ops.placeholder_with_default(
           np.random.randn(batch_size, decoder_max_time,
@@ -258,6 +259,8 @@ class BeamSearchDecoderTest(test.TestCase):
           num_units=attention_depth,
           memory=tiled_inputs,
           memory_sequence_length=tiled_sequence_length)
+      initial_state = beam_search_decoder.tile_batch(
+          initial_state, multiplier=beam_width)
       cell = attention_wrapper.AttentionWrapper(
           cell=cell,
           attention_mechanism=attention_mechanism,
@@ -265,6 +268,9 @@ class BeamSearchDecoderTest(test.TestCase):
           alignment_history=False)
     cell_state = cell.zero_state(
         dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
+    if has_attention:
+      cell_state = cell_state.clone(
+          cell_state=initial_state)
     bsd = beam_search_decoder.BeamSearchDecoder(
         cell=cell,
         embedding=embedding,
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index f1d0ab07711..1d1babda163 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -72,10 +72,30 @@ class FinalBeamSearchDecoderOutput(
   pass
 
 
-def tile_batch(t, multiplier, name=None):
-  """Tile the batch dimension of tensor t.
+def _tile_batch(t, multiplier):
+  """Core single-tensor implementation of tile_batch."""
+  t = ops.convert_to_tensor(t, name="t")
+  shape_t = array_ops.shape(t)
+  if t.shape.ndims is None or t.shape.ndims < 1:
+    raise ValueError("t must have statically known rank")
+  tiling = [1] * (t.shape.ndims + 1)
+  tiling[1] = multiplier
+  tiled_static_batch_size = (
+      t.shape[0].value * multiplier if t.shape[0].value is not None else None)
+  tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling)
+  tiled = array_ops.reshape(
+      tiled, array_ops.concat(([shape_t[0] * multiplier], shape_t[1:]), 0))
+  tiled.set_shape(
+      tensor_shape.TensorShape(
+          [tiled_static_batch_size]).concatenate(t.shape[1:]))
+  return tiled
 
-  This function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of
+
+def tile_batch(t, multiplier, name=None):
+  """Tile the batch dimension of a (possibly nested structure of) tensor(s) t.
+
+  For each tensor t in a (possibly nested structure) of tensors,
+  this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of
   minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape
   `[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries
   `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated
@@ -87,27 +107,16 @@ def tile_batch(t, multiplier, name=None):
     name: Name scope for any created operations.
 
   Returns:
-    A `Tensor` shaped `[batch_size * multiplier, ...]`.
+    A (possibly nested structure of) `Tensor` shaped
+    `[batch_size * multiplier, ...]`.
 
   Raises:
-    ValueError: if `t` does not have a statically known rank or it's < 1.
+    ValueError: if tensor(s) `t` do not have a statically known rank or
+      the rank is < 1.
   """
-  with ops.name_scope(name, "tile_batch", [t, multiplier]):
-    t = ops.convert_to_tensor(t, name="t")
-    shape_t = array_ops.shape(t)
-    if t.shape.ndims is None or t.shape.ndims < 1:
-      raise ValueError("t must have statically known rank")
-    tiling = [1] * (t.shape.ndims + 1)
-    tiling[1] = multiplier
-    tiled_static_batch_size = (
-        t.shape[0].value * multiplier if t.shape[0].value is not None else None)
-    tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling)
-    tiled = array_ops.reshape(
-        tiled, array_ops.concat(([shape_t[0] * multiplier], shape_t[1:]), 0))
-    tiled.set_shape(
-        tensor_shape.TensorShape(
-            [tiled_static_batch_size]).concatenate(t.shape[1:]))
-    return tiled
+  flat_t = nest.flatten(t)
+  with ops.name_scope(name, "tile_batch", flat_t + [multiplier]):
+    return nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t)
 
 
 def _check_maybe(t):
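
Usage sketch (not part of the patch): the updated unit test above shows the pattern this change enables; the snippet below restates it as a standalone TF 1.x example. The encoder, cell sizes, vocabulary constants, and variable names here are illustrative placeholders, not identifiers from this patch.

# Minimal sketch, assuming the TF 1.x contrib seq2seq API patched above.
# All shapes, names, and constants are illustrative placeholders.
import tensorflow as tf
from tensorflow.contrib import rnn, seq2seq

batch_size, beam_width = 5, 3

# A toy encoder: its final state is a nested LSTMStateTuple.
encoder_inputs = tf.random_normal([batch_size, 7, 8])
encoder_sequence_length = tf.constant([3, 2, 3, 1, 1])
encoder_cell = rnn.LSTMCell(num_units=9)
encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
    encoder_cell, encoder_inputs, sequence_length=encoder_sequence_length,
    dtype=tf.float32)

# tile_batch now accepts nested structures, so the LSTMStateTuple can be
# tiled in one call, alongside the attention memory and its lengths.
tiled_memory = seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)
tiled_lengths = seq2seq.tile_batch(
    encoder_sequence_length, multiplier=beam_width)
tiled_encoder_state = seq2seq.tile_batch(
    encoder_final_state, multiplier=beam_width)

attention_mechanism = seq2seq.BahdanauAttention(
    num_units=6, memory=tiled_memory, memory_sequence_length=tiled_lengths)
decoder_cell = seq2seq.AttentionWrapper(
    rnn.LSTMCell(num_units=9), attention_mechanism,
    attention_layer_size=6, alignment_history=False)

# Build the wrapper's zero state at the tiled batch size, then substitute
# the tiled encoder state as the inner cell state, as in the test above.
initial_state = decoder_cell.zero_state(
    dtype=tf.float32, batch_size=batch_size * beam_width)
initial_state = initial_state.clone(cell_state=tiled_encoder_state)

embedding = tf.get_variable("embedding", [20, 8])
decoder = seq2seq.BeamSearchDecoder(
    cell=decoder_cell,
    embedding=embedding,
    start_tokens=tf.fill([batch_size], 0),
    end_token=1,
    initial_state=initial_state,
    beam_width=beam_width)
decode_result = seq2seq.dynamic_decode(decoder, maximum_iterations=10)

The point of the nest.map_structure change above is the single tile_batch call on the nested encoder state: without it, each field of the LSTMStateTuple would have to be tiled by hand before cloning it into the AttentionWrapper state.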