Add softmax temperature to SampleEmbeddingHelper.
PiperOrigin-RevId: 165001965
This commit is contained in:
parent
5d1792f994
commit
31b3d407f6
@ -520,7 +520,8 @@ class SampleEmbeddingHelper(GreedyEmbeddingHelper):
|
||||
result through an embedding layer to get the next input.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding, start_tokens, end_token, seed=None):
|
||||
def __init__(self, embedding, start_tokens, end_token,
|
||||
softmax_temperature=None, seed=None):
|
||||
"""Initializer.
|
||||
|
||||
Args:
|
||||
@ -529,7 +530,12 @@ class SampleEmbeddingHelper(GreedyEmbeddingHelper):
|
||||
will be passed to the decoder input.
|
||||
start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
|
||||
end_token: `int32` scalar, the token that marks end of decoding.
|
||||
seed: The sampling seed.
|
||||
softmax_temperature: (Optional) `float32` scalar, value to divide the
|
||||
logits by before computing the softmax. Larger values (above 1.0) result
|
||||
in more random samples, while smaller values push the sampling
|
||||
distribution towards the argmax. Must be strictly greater than 0.
|
||||
Defaults to 1.0.
|
||||
seed: (Optional) The sampling seed.
|
||||
|
||||
Raises:
|
||||
ValueError: if `start_tokens` is not a 1D tensor or `end_token` is not a
|
||||
@ -537,6 +543,7 @@ class SampleEmbeddingHelper(GreedyEmbeddingHelper):
|
||||
"""
|
||||
super(SampleEmbeddingHelper, self).__init__(
|
||||
embedding, start_tokens, end_token)
|
||||
self._softmax_temperature = softmax_temperature
|
||||
self._seed = seed
|
||||
|
||||
def sample(self, time, outputs, state, name=None):
|
||||
@ -546,7 +553,12 @@ class SampleEmbeddingHelper(GreedyEmbeddingHelper):
|
||||
if not isinstance(outputs, ops.Tensor):
|
||||
raise TypeError("Expected outputs to be a single Tensor, got: %s" %
|
||||
type(outputs))
|
||||
sample_id_sampler = categorical.Categorical(logits=outputs)
|
||||
if self._softmax_temperature is None:
|
||||
logits = outputs
|
||||
else:
|
||||
logits = outputs / self._softmax_temperature
|
||||
|
||||
sample_id_sampler = categorical.Categorical(logits=logits)
|
||||
sample_ids = sample_id_sampler.sample(seed=self._seed)
|
||||
|
||||
return sample_ids
|
||||
|
Loading…
Reference in New Issue
Block a user