From 74bd8ff717eaf08bf64f4b16c0bca40173b19614 Mon Sep 17 00:00:00 2001
From: Eugene Brevdo
Date: Mon, 16 Oct 2017 11:32:49 -0700
Subject: [PATCH] [tf.contrib.seq2seq] Some light cleanup in beam search
 decoder code.

PiperOrigin-RevId: 172352767
---
 .../kernel_tests/beam_search_decoder_test.py  |  3 +-
 .../seq2seq/python/ops/beam_search_decoder.py | 71 ++++++++++---------
 2 files changed, 39 insertions(+), 35 deletions(-)

diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
index 2caeb9eb614..8d4ec4b4dbe 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -80,8 +80,7 @@ class TestEosMasking(test.TestCase):
     ])
 
     eos_token = 0
-    previously_finished = constant_op.constant(
-        [[0, 1, 0], [0, 1, 1]], dtype=dtypes.float32)
+    previously_finished = np.array([[0, 1, 0], [0, 1, 1]], dtype=bool)
 
     masked = beam_search_decoder._mask_probs(probs, eos_token,
                                              previously_finished)
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index e22912ac5c9..112ac57a1bc 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -20,9 +20,10 @@ from __future__ import print_function
 
 import collections
 
+import numpy as np
+
 from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
 from tensorflow.contrib.seq2seq.python.ops import decoder
-from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -390,17 +391,17 @@ class BeamSearchDecoder(decoder.Decoder):
     We do this so that we can use nest and not run into problems with shapes.
 
     Args:
-      t: Tensor of dimension [batch_size*beam_width, s]
-      s: Tensor, Python int, or TensorShape.
+      t: `Tensor`, either scalar or shaped `[batch_size * beam_width] + s`.
+      s: `Tensor`, Python int, or `TensorShape`.
 
     Returns:
-      Either a reshaped version of t with dimension
-      [batch_size, beam_width, s] if t's first dimension is of size
-      batch_size*beam_width or t if not.
+      If `t` is a matrix or higher order tensor, then the return value is
+      `t` reshaped to `[batch_size, beam_width] + s`. Otherwise `t` is
+      returned unchanged.
 
     Raises:
-      TypeError: If t is an instance of TensorArray.
-      ValueError: If the rank of t is not statically known.
+      TypeError: If `t` is an instance of `TensorArray`.
+      ValueError: If the rank of `t` is not statically known.
     """
     _check_maybe(t)
     if t.shape.ndims >= 1:
@@ -411,19 +412,19 @@
   def _maybe_merge_batch_beams(self, t, s):
     """Splits the tensor from a batch by beams into a batch of beams.
 
-    More exactly, t is a tensor of dimension [batch_size*beam_width, s]. We
-    reshape this into [batch_size, beam_width, s]
+    More exactly, `t` is a tensor of dimension `[batch_size * beam_width] + s`,
+    and we reshape it to `[batch_size, beam_width] + s`.
 
     Args:
-      t: Tensor of dimension [batch_size*beam_width, s]
-      s: Tensor, Python int, or TensorShape.
+      t: `Tensor` of dimension `[batch_size * beam_width] + s`.
+      s: `Tensor`, Python int, or `TensorShape`.
 
     Returns:
-      A reshaped version of t with dimension [batch_size, beam_width, s].
+      A reshaped version of `t` with shape `[batch_size, beam_width] + s`.
 
     Raises:
-      TypeError: If t is an instance of TensorArray.
-      ValueError: If the rank of t is not statically known.
+      TypeError: If `t` is an instance of `TensorArray`.
+      ValueError: If the rank of `t` is not statically known.
     """
     _check_maybe(t)
     if t.shape.ndims >= 2:
@@ -521,14 +522,12 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
 
   # Calculate the continuation lengths by adding to all continuing beams.
   vocab_size = logits.shape[-1].value or array_ops.shape(logits)[-1]
   lengths_to_add = array_ops.one_hot(
-      indices=array_ops.tile(
-          array_ops.reshape(end_token, [1, 1]), [batch_size, beam_width]),
+      indices=array_ops.fill([batch_size, beam_width], end_token),
       depth=vocab_size,
-      on_value=constant_op.constant(0, dtype=dtypes.int64),
-      off_value=constant_op.constant(1, dtype=dtypes.int64),
+      on_value=np.int64(0), off_value=np.int64(1),
       dtype=dtypes.int64)
-  add_mask = (1 - math_ops.to_int64(previously_finished))
-  lengths_to_add = array_ops.expand_dims(add_mask, 2) * lengths_to_add
+  add_mask = math_ops.to_int64(math_ops.logical_not(previously_finished))
+  lengths_to_add *= array_ops.expand_dims(add_mask, 2)
   new_prediction_lengths = (
       lengths_to_add + array_ops.expand_dims(prediction_lengths, 2))
@@ -592,9 +591,7 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
   #   1. Finished beams remain unchanged
   #   2. Beams that are now finished (EOS predicted) remain unchanged
   #   3. Beams that are not yet finished have their length increased by 1
-  lengths_to_add = math_ops.to_int64(
-      math_ops.not_equal(next_word_ids, end_token))
-  lengths_to_add = (1 - math_ops.to_int64(next_finished)) * lengths_to_add
+  lengths_to_add = math_ops.to_int64(math_ops.logical_not(next_finished))
   next_prediction_len = _tensor_gather_helper(
       gather_indices=next_beam_ids,
       gather_from=beam_state.lengths,
@@ -652,13 +649,20 @@ def _get_scores(log_probs, sequence_lengths, length_penalty_weight):
 def _length_penalty(sequence_lengths, penalty_factor):
   """Calculates the length penalty. See https://arxiv.org/abs/1609.08144.
 
+  Returns the length penalty tensor:
+  ```
+  [(5+sequence_lengths)/6]**penalty_factor
+  ```
+  where all operations are performed element-wise.
+
   Args:
-    sequence_lengths: The sequence length of all hypotheses, a tensor
-      of shape [beam_size, vocab_size].
+    sequence_lengths: `Tensor`, the sequence length of each hypothesis.
     penalty_factor: A scalar that weights the length penalty.
 
   Returns:
-    The length penalty factor, a tensor fo shape [beam_size].
+    If the penalty is `0`, returns the scalar `1.0`. Otherwise returns
+    the length penalty factor, a tensor with the same shape as
+    `sequence_lengths`.
   """
   penalty_factor = ops.convert_to_tensor(penalty_factor, name="penalty_factor")
   penalty_factor.set_shape(())  # penalty should be a scalar.
@@ -680,8 +684,7 @@ def _mask_probs(probs, eos_token, finished):
     eos_token: An int32 id corresponding to the EOS token to allocate
       probability to.
     finished: A boolean tensor of shape `[batch_size, beam_width]` that
-      specifies which
-      elements in the beam are finished already.
+      specifies which elements in the beam are finished already.
 
   Returns:
     A tensor of shape `[batch_size, beam_width, vocab_size]`, where unfinished
@@ -689,17 +692,19 @@ def _mask_probs(probs, eos_token, finished):
     probability on the EOS token.
   """
   vocab_size = array_ops.shape(probs)[2]
-  finished_mask = array_ops.expand_dims(
-      math_ops.to_float(1. - math_ops.to_float(finished)), 2)
+  finished_mask = math_ops.cast(array_ops.expand_dims(finished, 2), probs.dtype)
+  not_finished_mask = math_ops.cast(
+      array_ops.expand_dims(math_ops.logical_not(finished), 2),
+      probs.dtype)
   # These examples are not finished and we leave them
-  non_finished_examples = finished_mask * probs
+  non_finished_examples = not_finished_mask * probs
   # All finished examples are replaced with a vector that has all
   # probability on EOS
   finished_row = array_ops.one_hot(
       eos_token,
       vocab_size,
       dtype=probs.dtype,
       on_value=0.,
       off_value=probs.dtype.min)
-  finished_examples = (1. - finished_mask) * finished_row
+  finished_examples = finished_mask * finished_row
   return finished_examples + non_finished_examples
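
The first `_beam_search_step` hunk builds `lengths_to_add` as a one-hot with
`on_value=0` and `off_value=1`: continuing a beam with any non-EOS token adds
one to its length, choosing EOS adds nothing, and `add_mask` zeroes the update
for beams that had already finished. A minimal NumPy sketch of that logic with
hypothetical toy sizes (not the TF implementation itself):

```python
import numpy as np

batch_size, beam_width, vocab_size, end_token = 1, 2, 4, 3
previously_finished = np.array([[False, True]])

# One-hot with on_value=0 / off_value=1: a 0 in the EOS column, 1 elsewhere.
eos_indices = np.full([batch_size, beam_width], end_token)
lengths_to_add = (np.arange(vocab_size) != eos_indices[..., None]).astype(np.int64)

# Beams that were already finished never grow, whatever token is chosen.
add_mask = (~previously_finished).astype(np.int64)
lengths_to_add *= add_mask[:, :, None]

print(lengths_to_add)
# [[[1 1 1 0]    beam 0: +1 for any continuation except EOS
#   [0 0 0 0]]]  beam 1: already finished, its length is frozen
```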
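The docstring added to `_length_penalty` is the length normalization from GNMT
(https://arxiv.org/abs/1609.08144), `[(5 + length) / 6] ** penalty_factor`;
`_get_scores` divides the accumulated log probabilities by it. A quick
plain-Python check with illustrative numbers only:

```python
def length_penalty(sequence_length, penalty_factor):
  # [(5 + sequence_length) / 6] ** penalty_factor, element-wise in TF.
  return ((5. + sequence_length) / 6.) ** penalty_factor

print(length_penalty(7, 0.0))  # 1.0    -- a zero penalty disables normalization
print(length_penalty(7, 1.0))  # 2.0    -- (5 + 7) / 6
print(length_penalty(7, 0.5))  # ~1.414 -- square root of the above
```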
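Finally, the `_mask_probs` rewrite keeps the contract the function always had:
unfinished beams keep their log probabilities, and finished beams are
overwritten with a row that puts all probability on EOS (log-prob 0 there,
`dtype.min` everywhere else). A self-contained NumPy sketch of that contract,
using `np.where` in place of the mask-multiply-and-add in the patch:

```python
import numpy as np

def mask_probs(probs, eos_token, finished):
  """Finished beams keep only EOS; unfinished beams pass through unchanged."""
  vocab_size = probs.shape[2]
  # Row for finished beams: log-prob 0 on EOS, effectively -inf elsewhere.
  finished_row = np.full([vocab_size], np.finfo(probs.dtype).min,
                         dtype=probs.dtype)
  finished_row[eos_token] = 0.
  # Broadcast `finished` over the vocab dimension and select per beam.
  return np.where(finished[:, :, np.newaxis], finished_row, probs)

probs = np.log(np.random.dirichlet(np.ones(5), (2, 3))).astype(np.float32)
finished = np.array([[False, True, False], [False, True, True]])
masked = mask_probs(probs, eos_token=0, finished=finished)
assert masked[0, 1, 0] == 0.0                  # finished beam: all mass on EOS
assert np.allclose(masked[0, 0], probs[0, 0])  # unfinished beam: untouched
```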