[tf.contrib.seq2seq] Some light cleanup in beam search decoder code.
PiperOrigin-RevId: 172352767
parent 2487732ff1
commit 74bd8ff717
tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -80,8 +80,7 @@ class TestEosMasking(test.TestCase):
     ])
 
     eos_token = 0
-    previously_finished = constant_op.constant(
-        [[0, 1, 0], [0, 1, 1]], dtype=dtypes.float32)
+    previously_finished = np.array([[0, 1, 0], [0, 1, 1]], dtype=bool)
     masked = beam_search_decoder._mask_probs(probs, eos_token,
                                              previously_finished)
 
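The test now passes `previously_finished` as a boolean NumPy array, matching the new contract of `_mask_probs` (see the docstring changes further down). A minimal NumPy sketch of that contract, with the hypothetical name `mask_probs_ref` standing in for the real op:

```
import numpy as np

def mask_probs_ref(probs, eos_token, finished):
  # probs: [batch, beam, vocab] scores; finished: [batch, beam] bool.
  # Finished beams keep all probability on eos_token; every other entry
  # in their row is pushed down to the smallest representable value.
  finished_row = np.full(probs.shape[-1], np.finfo(probs.dtype).min)
  finished_row[eos_token] = 0.
  return np.where(finished[..., None], finished_row, probs)
```

With `eos_token = 0` and the mask above, beams (0, 1), (1, 1) and (1, 2) collapse onto EOS while the rest pass through unchanged.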
tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -20,9 +20,10 @@ from __future__ import print_function
 
 import collections
 
+import numpy as np
+
 from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
 from tensorflow.contrib.seq2seq.python.ops import decoder
-from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -390,17 +391,17 @@ class BeamSearchDecoder(decoder.Decoder):
     We do this so that we can use nest and not run into problems with shapes.
 
     Args:
-      t: Tensor of dimension [batch_size*beam_width, s]
-      s: Tensor, Python int, or TensorShape.
+      t: `Tensor`, either scalar or shaped `[batch_size * beam_width] + s`.
+      s: `Tensor`, Python int, or `TensorShape`.
 
     Returns:
-      Either a reshaped version of t with dimension
-      [batch_size, beam_width, s] if t's first dimension is of size
-      batch_size*beam_width or t if not.
+      If `t` is a matrix or higher order tensor, then the return value is
+      `t` reshaped to `[batch_size, beam_width] + s`. Otherwise `t` is
+      returned unchanged.
 
     Raises:
-      TypeError: If t is an instance of TensorArray.
-      ValueError: If the rank of t is not statically known.
+      TypeError: If `t` is an instance of `TensorArray`.
+      ValueError: If the rank of `t` is not statically known.
     """
     _check_maybe(t)
     if t.shape.ndims >= 1:
@@ -411,19 +412,19 @@ class BeamSearchDecoder(decoder.Decoder):
   def _maybe_merge_batch_beams(self, t, s):
     """Splits the tensor from a batch by beams into a batch of beams.
 
-    More exactly, t is a tensor of dimension [batch_size*beam_width, s]. We
-    reshape this into [batch_size, beam_width, s]
+    More exactly, `t` is a tensor of dimension `[batch_size * beam_width] + s`,
+    then we reshape it to `[batch_size, beam_width] + s`.
 
     Args:
-      t: Tensor of dimension [batch_size*beam_width, s]
-      s: Tensor, Python int, or TensorShape.
+      t: `Tensor` of dimension `[batch_size * beam_width] + s`.
+      s: `Tensor`, Python int, or `TensorShape`.
 
     Returns:
-      A reshaped version of t with dimension [batch_size, beam_width, s].
+      A reshaped version of t with shape `[batch_size, beam_width] + s`.
 
     Raises:
-      TypeError: If t is an instance of TensorArray.
-      ValueError: If the rank of t is not statically known.
+      TypeError: If `t` is an instance of `TensorArray`.
+      ValueError: If the rank of `t` is not statically known.
     """
     _check_maybe(t)
     if t.shape.ndims >= 2:
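Both helpers reduce to a `tf.reshape` between the flattened and the per-beam layout; roughly, ignoring the `TensorArray` and static-rank checks, and assuming `s` is a static `TensorShape` (a sketch, not the real methods):

```
import tensorflow as tf

def split_batch_beams(t, s, batch_size, beam_width):
  # [batch_size * beam_width] + s  ->  [batch_size, beam_width] + s
  return tf.reshape(t, [batch_size, beam_width] + s.as_list())

def merge_batch_beams(t, s, batch_size, beam_width):
  # [batch_size, beam_width] + s  ->  [batch_size * beam_width] + s
  return tf.reshape(t, [batch_size * beam_width] + s.as_list())
```

The real methods additionally set the static shape of the result so that downstream `nest` operations keep working.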
@@ -521,14 +522,12 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
   # Calculate the continuation lengths by adding to all continuing beams.
   vocab_size = logits.shape[-1].value or array_ops.shape(logits)[-1]
   lengths_to_add = array_ops.one_hot(
-      indices=array_ops.tile(
-          array_ops.reshape(end_token, [1, 1]), [batch_size, beam_width]),
+      indices=array_ops.fill([batch_size, beam_width], end_token),
       depth=vocab_size,
-      on_value=constant_op.constant(0, dtype=dtypes.int64),
-      off_value=constant_op.constant(1, dtype=dtypes.int64),
+      on_value=np.int64(0), off_value=np.int64(1),
       dtype=dtypes.int64)
-  add_mask = (1 - math_ops.to_int64(previously_finished))
-  lengths_to_add = array_ops.expand_dims(add_mask, 2) * lengths_to_add
+  add_mask = math_ops.to_int64(math_ops.logical_not(previously_finished))
+  lengths_to_add *= array_ops.expand_dims(add_mask, 2)
   new_prediction_lengths = (
       lengths_to_add + array_ops.expand_dims(prediction_lengths, 2))
 
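`fill` plus an inverted one-hot (`on_value=0`, `off_value=1`) builds a `[batch, beam, vocab]` tensor that is 1 for every candidate token except `end_token`; multiplying by the not-finished mask then zeroes the update for beams that already emitted EOS. The same computation in NumPy, with hypothetical toy sizes:

```
import numpy as np

batch_size, beam_width, vocab_size, end_token = 1, 2, 4, 3
lengths_to_add = np.ones((batch_size, beam_width, vocab_size), np.int64)
lengths_to_add[:, :, end_token] = 0          # inverted one-hot on EOS
previously_finished = np.array([[False, True]])
add_mask = (~previously_finished).astype(np.int64)
lengths_to_add *= add_mask[:, :, None]       # finished beams add nothing
print(lengths_to_add)  # beam 0: [1 1 1 0], beam 1: [0 0 0 0]
```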
@@ -592,9 +591,7 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
   # 1. Finished beams remain unchanged
   # 2. Beams that are now finished (EOS predicted) remain unchanged
   # 3. Beams that are not yet finished have their length increased by 1
-  lengths_to_add = math_ops.to_int64(
-      math_ops.not_equal(next_word_ids, end_token))
-  lengths_to_add = (1 - math_ops.to_int64(next_finished)) * lengths_to_add
+  lengths_to_add = math_ops.to_int64(math_ops.logical_not(next_finished))
   next_prediction_len = _tensor_gather_helper(
       gather_indices=next_beam_ids,
       gather_from=beam_state.lengths,
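The dropped `not_equal(next_word_ids, end_token)` factor was redundant: a beam that picks EOS at this step is already marked finished in `next_finished`, so `logical_not(next_finished)` alone encodes all three rules. A small NumPy illustration:

```
import numpy as np

# next_finished = previously_finished | (next_word_ids == end_token)
next_finished = np.array([[False, True, True]])
lengths_to_add = (~next_finished).astype(np.int64)
print(lengths_to_add)  # [[1 0 0]]: only the unfinished beam grows by 1
```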
@@ -652,13 +649,20 @@ def _get_scores(log_probs, sequence_lengths, length_penalty_weight):
 def _length_penalty(sequence_lengths, penalty_factor):
   """Calculates the length penalty. See https://arxiv.org/abs/1609.08144.
 
+  Returns the length penalty tensor:
+  ```
+  [(5+sequence_lengths)/6]**penalty_factor
+  ```
+  where all operations are performed element-wise.
+
   Args:
-    sequence_lengths: The sequence length of all hypotheses, a tensor
-      of shape [beam_size, vocab_size].
+    sequence_lengths: `Tensor`, the sequence lengths of each hypotheses.
     penalty_factor: A scalar that weights the length penalty.
 
   Returns:
-    The length penalty factor, a tensor fo shape [beam_size].
+    If the penalty is `0`, returns the scalar `1.0`. Otherwise returns
+    the length penalty factor, a tensor with the same shape as
+    `sequence_lengths`.
   """
   penalty_factor = ops.convert_to_tensor(penalty_factor, name="penalty_factor")
   penalty_factor.set_shape(())  # penalty should be a scalar.
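As a worked instance of the formula in the new docstring: with `penalty_factor = 0.6` (the value suggested in the GNMT paper) and a hypothesis of length 7, the penalty is `((5 + 7) / 6) ** 0.6 ≈ 1.516`, so longer hypotheses get divided by a larger factor. A NumPy sketch of the same math (the TF code computes the equivalent quantity on tensors):

```
import numpy as np

def length_penalty_ref(sequence_lengths, penalty_factor):
  if penalty_factor == 0:
    return 1.0  # mirrors the short-circuit documented above
  lengths = np.asarray(sequence_lengths, dtype=np.float64)
  return ((5. + lengths) / 6.) ** penalty_factor

print(length_penalty_ref([7, 10], 0.6))  # ~[1.516 1.733]
```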
@@ -680,8 +684,7 @@ def _mask_probs(probs, eos_token, finished):
     eos_token: An int32 id corresponding to the EOS token to allocate
       probability to.
     finished: A boolean tensor of shape `[batch_size, beam_width]` that
-      specifies which
-      elements in the beam are finished already.
+      specifies which elements in the beam are finished already.
 
   Returns:
     A tensor of shape `[batch_size, beam_width, vocab_size]`, where unfinished
@@ -689,10 +692,12 @@ def _mask_probs(probs, eos_token, finished):
     probability on the EOS token.
   """
   vocab_size = array_ops.shape(probs)[2]
-  finished_mask = array_ops.expand_dims(
-      math_ops.to_float(1. - math_ops.to_float(finished)), 2)
+  finished_mask = math_ops.cast(array_ops.expand_dims(finished, 2), probs.dtype)
+  not_finished_mask = math_ops.cast(
+      array_ops.expand_dims(math_ops.logical_not(finished), 2),
+      probs.dtype)
   # These examples are not finished and we leave them
-  non_finished_examples = finished_mask * probs
+  non_finished_examples = not_finished_mask * probs
   # All finished examples are replaced with a vector that has all
   # probability on EOS
   finished_row = array_ops.one_hot(
@@ -701,7 +706,7 @@ def _mask_probs(probs, eos_token, finished):
       dtype=probs.dtype,
       on_value=0.,
       off_value=probs.dtype.min)
-  finished_examples = (1. - finished_mask) * finished_row
+  finished_examples = finished_mask * finished_row
   return finished_examples + non_finished_examples
 
 
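After the cleanup, the two complementary {0, 1} masks turn the final sum into an element-wise select between the original row and the EOS row. The same select written out in NumPy, with a 2-token vocabulary and EOS id 0 (toy values, not from the tests):

```
import numpy as np

finished = np.array([[False, True]])                    # [batch, beam]
probs = np.array([[[0.1, 0.9], [0.4, 0.6]]], np.float32)
finished_mask = finished[..., None].astype(probs.dtype)
not_finished_mask = (~finished)[..., None].astype(probs.dtype)
finished_row = np.array([0., np.finfo(probs.dtype).min], probs.dtype)
out = finished_mask * finished_row + not_finished_mask * probs
print(out)  # beam 0 keeps [0.1 0.9]; beam 1 collapses onto EOS
```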