Use tf.where instead of multiplies when masking probabilities in the BeamSearchDecoder.

PiperOrigin-RevId: 173273139
This commit is contained in:
A. Unique TensorFlower 2017-10-24 10:11:42 -07:00 committed by TensorFlower Gardener
parent 58b071639d
commit 377dd3d0d5

View File

@ -715,12 +715,6 @@ def _mask_probs(probs, eos_token, finished):
probability on the EOS token.
"""
vocab_size = array_ops.shape(probs)[2]
finished_mask = math_ops.cast(array_ops.expand_dims(finished, 2), probs.dtype)
not_finished_mask = math_ops.cast(
array_ops.expand_dims(math_ops.logical_not(finished), 2),
probs.dtype)
# These examples are not finished and we leave them
non_finished_examples = not_finished_mask * probs
# All finished examples are replaced with a vector that has all
# probability on EOS
finished_row = array_ops.one_hot(
@ -729,8 +723,13 @@ def _mask_probs(probs, eos_token, finished):
dtype=probs.dtype,
on_value=0.,
off_value=probs.dtype.min)
finished_examples = finished_mask * finished_row
return finished_examples + non_finished_examples
finished_probs = array_ops.tile(
array_ops.reshape(finished_row, [1, 1, -1]),
array_ops.concat([array_ops.shape(finished), [1]], 0))
finished_mask = array_ops.tile(
array_ops.expand_dims(finished, 2), [1, 1, vocab_size])
return array_ops.where(finished_mask, finished_probs, probs)
def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size,