diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index a88d4f5b8b5..5be0c92243d 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -715,12 +715,6 @@ def _mask_probs(probs, eos_token, finished): probability on the EOS token. """ vocab_size = array_ops.shape(probs)[2] - finished_mask = math_ops.cast(array_ops.expand_dims(finished, 2), probs.dtype) - not_finished_mask = math_ops.cast( - array_ops.expand_dims(math_ops.logical_not(finished), 2), - probs.dtype) - # These examples are not finished and we leave them - non_finished_examples = not_finished_mask * probs # All finished examples are replaced with a vector that has all # probability on EOS finished_row = array_ops.one_hot( @@ -729,8 +723,13 @@ def _mask_probs(probs, eos_token, finished): dtype=probs.dtype, on_value=0., off_value=probs.dtype.min) - finished_examples = finished_mask * finished_row - return finished_examples + non_finished_examples + finished_probs = array_ops.tile( + array_ops.reshape(finished_row, [1, 1, -1]), + array_ops.concat([array_ops.shape(finished), [1]], 0)) + finished_mask = array_ops.tile( + array_ops.expand_dims(finished, 2), [1, 1, vocab_size]) + + return array_ops.where(finished_mask, finished_probs, probs) def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size,