Use tf.where instead of multiplies when masking probabilities in the BeamSearchDecoder.
PiperOrigin-RevId: 173273139
This commit is contained in:
parent
58b071639d
commit
377dd3d0d5
@ -715,12 +715,6 @@ def _mask_probs(probs, eos_token, finished):
|
||||
probability on the EOS token.
|
||||
"""
|
||||
vocab_size = array_ops.shape(probs)[2]
|
||||
finished_mask = math_ops.cast(array_ops.expand_dims(finished, 2), probs.dtype)
|
||||
not_finished_mask = math_ops.cast(
|
||||
array_ops.expand_dims(math_ops.logical_not(finished), 2),
|
||||
probs.dtype)
|
||||
# These examples are not finished and we leave them
|
||||
non_finished_examples = not_finished_mask * probs
|
||||
# All finished examples are replaced with a vector that has all
|
||||
# probability on EOS
|
||||
finished_row = array_ops.one_hot(
|
||||
@ -729,8 +723,13 @@ def _mask_probs(probs, eos_token, finished):
|
||||
dtype=probs.dtype,
|
||||
on_value=0.,
|
||||
off_value=probs.dtype.min)
|
||||
finished_examples = finished_mask * finished_row
|
||||
return finished_examples + non_finished_examples
|
||||
finished_probs = array_ops.tile(
|
||||
array_ops.reshape(finished_row, [1, 1, -1]),
|
||||
array_ops.concat([array_ops.shape(finished), [1]], 0))
|
||||
finished_mask = array_ops.tile(
|
||||
array_ops.expand_dims(finished, 2), [1, 1, vocab_size])
|
||||
|
||||
return array_ops.where(finished_mask, finished_probs, probs)
|
||||
|
||||
|
||||
def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size,
|
||||
|
Loading…
Reference in New Issue
Block a user