Add softmax temperature to SampleEmbeddingHelper.

PiperOrigin-RevId: 165001965
Author: Adam Roberts, 2017-08-11 11:25:01 -07:00 (committed by TensorFlower Gardener)
parent 5d1792f994
commit 31b3d407f6


@@ -520,7 +520,8 @@ class SampleEmbeddingHelper(GreedyEmbeddingHelper):
   result through an embedding layer to get the next input.
   """
 
-  def __init__(self, embedding, start_tokens, end_token, seed=None):
+  def __init__(self, embedding, start_tokens, end_token,
+               softmax_temperature=None, seed=None):
     """Initializer.
 
     Args:
@@ -529,7 +530,12 @@ class SampleEmbeddingHelper(GreedyEmbeddingHelper):
         will be passed to the decoder input.
       start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
       end_token: `int32` scalar, the token that marks end of decoding.
-      seed: The sampling seed.
+      softmax_temperature: (Optional) `float32` scalar, value to divide the
+        logits by before computing the softmax. Larger values (above 1.0) result
+        in more random samples, while smaller values push the sampling
+        distribution towards the argmax. Must be strictly greater than 0.
+        Defaults to 1.0.
+      seed: (Optional) The sampling seed.
 
     Raises:
       ValueError: if `start_tokens` is not a 1D tensor or `end_token` is not a
@@ -537,6 +543,7 @@ class SampleEmbeddingHelper(GreedyEmbeddingHelper):
     """
     super(SampleEmbeddingHelper, self).__init__(
         embedding, start_tokens, end_token)
+    self._softmax_temperature = softmax_temperature
     self._seed = seed
 
   def sample(self, time, outputs, state, name=None):
@@ -546,7 +553,12 @@ class SampleEmbeddingHelper(GreedyEmbeddingHelper):
     if not isinstance(outputs, ops.Tensor):
       raise TypeError("Expected outputs to be a single Tensor, got: %s" %
                       type(outputs))
-    sample_id_sampler = categorical.Categorical(logits=outputs)
+    if self._softmax_temperature is None:
+      logits = outputs
+    else:
+      logits = outputs / self._softmax_temperature
+
+    sample_id_sampler = categorical.Categorical(logits=logits)
     sample_ids = sample_id_sampler.sample(seed=self._seed)
     return sample_ids
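
For reference, the effect of the new argument can be sketched without TensorFlow: sample() now divides the logits by the temperature before building the categorical distribution, so temperatures above 1.0 flatten the distribution (more random samples) while temperatures below 1.0 sharpen it towards the argmax. The snippet below is a minimal illustration of that scaling step only; the function name softmax_with_temperature is made up for the example and is not part of the helper's API.

# Illustrative sketch of the scaling performed in SampleEmbeddingHelper.sample().
import math

def softmax_with_temperature(logits, temperature=None):
  """Softmax of `logits / temperature`; `None` leaves the logits unchanged."""
  if temperature is not None:
    logits = [l / temperature for l in logits]
  m = max(logits)  # subtract the max for numerical stability
  exps = [math.exp(l - m) for l in logits]
  total = sum(exps)
  return [e / total for e in exps]

logits = [2.0, 1.0, 0.1]
print(softmax_with_temperature(logits))       # no temperature:       ~[0.66, 0.24, 0.10]
print(softmax_with_temperature(logits, 2.0))  # flatter, more random:  ~[0.50, 0.30, 0.19]
print(softmax_with_temperature(logits, 0.5))  # sharper, near-argmax:  ~[0.86, 0.12, 0.02]

Callers opt in by passing, for example, softmax_temperature=0.7 to the SampleEmbeddingHelper constructor; leaving it at the default None preserves the previous sampling behavior, since the logits then reach categorical.Categorical unchanged.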