diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py index 9d3f8ad4411..a716dcba738 100644 --- a/tensorflow/contrib/seq2seq/python/ops/helper.py +++ b/tensorflow/contrib/seq2seq/python/ops/helper.py @@ -520,7 +520,8 @@ class SampleEmbeddingHelper(GreedyEmbeddingHelper): result through an embedding layer to get the next input. """ - def __init__(self, embedding, start_tokens, end_token, seed=None): + def __init__(self, embedding, start_tokens, end_token, + softmax_temperature=None, seed=None): """Initializer. Args: @@ -529,7 +530,12 @@ class SampleEmbeddingHelper(GreedyEmbeddingHelper): will be passed to the decoder input. start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. end_token: `int32` scalar, the token that marks end of decoding. - seed: The sampling seed. + softmax_temperature: (Optional) `float32` scalar, value to divide the + logits by before computing the softmax. Larger values (above 1.0) result + in more random samples, while smaller values push the sampling + distribution towards the argmax. Must be strictly greater than 0. + Defaults to 1.0. + seed: (Optional) The sampling seed. Raises: ValueError: if `start_tokens` is not a 1D tensor or `end_token` is not a @@ -537,6 +543,7 @@ class SampleEmbeddingHelper(GreedyEmbeddingHelper): """ super(SampleEmbeddingHelper, self).__init__( embedding, start_tokens, end_token) + self._softmax_temperature = softmax_temperature self._seed = seed def sample(self, time, outputs, state, name=None): @@ -546,7 +553,12 @@ class SampleEmbeddingHelper(GreedyEmbeddingHelper): if not isinstance(outputs, ops.Tensor): raise TypeError("Expected outputs to be a single Tensor, got: %s" % type(outputs)) - sample_id_sampler = categorical.Categorical(logits=outputs) + if self._softmax_temperature is None: + logits = outputs + else: + logits = outputs / self._softmax_temperature + + sample_id_sampler = categorical.Categorical(logits=logits) sample_ids = sample_id_sampler.sample(seed=self._seed) return sample_ids