Data type support for seq2seq attention mechanisms (#12007)
* Added dtype support to the seq2seq attention mechanisms and AttentionWrapper
* Addressed review comments
* Convert string dtype arguments to tf.DType
Parent: 1dab8efeaf
Commit: b1ce37a0fe
@@ -149,7 +149,7 @@ class _BaseAttentionMechanism(AttentionMechanism):
                memory_sequence_length=None,
                memory_layer=None,
                check_inner_dims_defined=True,
-               score_mask_value=float("-inf"),
+               score_mask_value=None,
                name=None):
     """Construct base AttentionMechanism class.

@@ -187,9 +187,12 @@ class _BaseAttentionMechanism(AttentionMechanism):
           "memory_layer is not a Layer: %s" % type(memory_layer).__name__)
     self._query_layer = query_layer
     self._memory_layer = memory_layer
+    self.dtype = memory_layer.dtype
     if not callable(probability_fn):
       raise TypeError("probability_fn must be callable, saw type: %s" %
                       type(probability_fn).__name__)
+    if score_mask_value is None:
+      score_mask_value = dtypes.as_dtype(self._memory_layer.dtype).as_numpy_dtype(-np.inf)
     self._probability_fn = lambda score, prev: (  # pylint:disable=g-long-lambda
         probability_fn(
             _maybe_mask_score(score, memory_sequence_length, score_mask_value),
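To make the new default concrete: when `score_mask_value` is left as `None`, it becomes negative infinity expressed in the memory layer's dtype. A minimal standalone sketch of that conversion, using the same `dtypes` module the file already imports (the float16 choice here is only an illustration):

import numpy as np
from tensorflow.python.framework import dtypes

# -inf cast to the memory layer's numpy dtype, here float16 for illustration.
memory_dtype = dtypes.float16
score_mask_value = dtypes.as_dtype(memory_dtype).as_numpy_dtype(-np.inf)
print(score_mask_value)        # -inf
print(score_mask_value.dtype)  # float16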
@@ -334,7 +337,8 @@ class LuongAttention(_BaseAttentionMechanism):
                memory_sequence_length=None,
                scale=False,
                probability_fn=None,
-               score_mask_value=float("-inf"),
+               score_mask_value=None,
+               dtype=None,
                name="LuongAttention"):
     """Construct the AttentionMechanism mechanism.

@@ -353,17 +357,20 @@ class LuongAttention(_BaseAttentionMechanism):
       score_mask_value: (optional) The mask value for score before passing into
         `probability_fn`. The default is -inf. Only used if
         `memory_sequence_length` is not None.
+      dtype: The data type for the memory layer of the attention mechanism.
       name: Name to use when creating ops.
     """
     # For LuongAttention, we only transform the memory layer; thus
     # num_units **must** match the expected query depth.
     if probability_fn is None:
       probability_fn = nn_ops.softmax
+    if dtype is None:
+      dtype = dtypes.float32
     wrapped_probability_fn = lambda score, _: probability_fn(score)
     super(LuongAttention, self).__init__(
         query_layer=None,
         memory_layer=layers_core.Dense(
-            num_units, name="memory_layer", use_bias=False),
+            num_units, name="memory_layer", use_bias=False, dtype=dtype),
         memory=memory,
         probability_fn=wrapped_probability_fn,
         memory_sequence_length=memory_sequence_length,
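As a usage sketch of the new argument (hypothetical shapes and placeholder names, assuming the TF 1.x `tf.contrib.seq2seq` API), passing `dtype` lets the memory layer match a half-precision encoder instead of defaulting to float32:

import tensorflow as tf

batch_size, max_time, encoder_depth, num_units = 4, 7, 32, 64
encoder_outputs = tf.placeholder(tf.float16, [batch_size, max_time, encoder_depth])
source_lengths = tf.placeholder(tf.int32, [batch_size])

attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units=num_units,
    memory=encoder_outputs,
    memory_sequence_length=source_lengths,
    dtype=tf.float16)  # memory_layer is created with dtype float16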
@@ -475,7 +482,8 @@ class BahdanauAttention(_BaseAttentionMechanism):
                memory_sequence_length=None,
                normalize=False,
                probability_fn=None,
-               score_mask_value=float("-inf"),
+               score_mask_value=None,
+               dtype=None,
                name="BahdanauAttention"):
     """Construct the Attention mechanism.

@@ -494,16 +502,20 @@ class BahdanauAttention(_BaseAttentionMechanism):
       score_mask_value: (optional) The mask value for score before passing into
         `probability_fn`. The default is -inf. Only used if
         `memory_sequence_length` is not None.
+      dtype: The data type for the query and memory layers of the attention
+        mechanism.
       name: Name to use when creating ops.
     """
     if probability_fn is None:
       probability_fn = nn_ops.softmax
+    if dtype is None:
+      dtype = dtypes.float32
     wrapped_probability_fn = lambda score, _: probability_fn(score)
     super(BahdanauAttention, self).__init__(
         query_layer=layers_core.Dense(
-            num_units, name="query_layer", use_bias=False),
+            num_units, name="query_layer", use_bias=False, dtype=dtype),
         memory_layer=layers_core.Dense(
-            num_units, name="memory_layer", use_bias=False),
+            num_units, name="memory_layer", use_bias=False, dtype=dtype),
         memory=memory,
         probability_fn=wrapped_probability_fn,
         memory_sequence_length=memory_sequence_length,
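The Bahdanau variant threads the same `dtype` through both its query and memory layers, so the decoder queries fed to it should use that dtype as well. A brief sketch under the same assumptions as the Luong example above:

import tensorflow as tf

encoder_outputs = tf.placeholder(tf.float16, [4, 7, 32])

attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
    num_units=64,
    memory=encoder_outputs,
    dtype=tf.float16)  # query_layer and memory_layer both use float16

# Alignments and scores produced by this mechanism are float16, so the query
# tensors (e.g. decoder cell outputs) should be float16 too.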
@@ -734,11 +746,12 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism):
                memory,
                memory_sequence_length=None,
                normalize=False,
-               score_mask_value=float("-inf"),
+               score_mask_value=None,
                sigmoid_noise=0.,
                sigmoid_noise_seed=None,
                score_bias_init=0.,
                mode="parallel",
+               dtype=None,
                name="BahdanauMonotonicAttention"):
     """Construct the Attention mechanism.

@@ -762,17 +775,21 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism):
       mode: How to compute the attention distribution. Must be one of
         'recursive', 'parallel', or 'hard'. See the docstring for
         `tf.contrib.seq2seq.monotonic_attention` for more information.
+      dtype: The data type for the query and memory layers of the attention
+        mechanism.
       name: Name to use when creating ops.
     """
     # Set up the monotonic probability fn with supplied parameters
+    if dtype is None:
+      dtype = dtypes.float32
     wrapped_probability_fn = functools.partial(
         _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode,
         seed=sigmoid_noise_seed)
     super(BahdanauMonotonicAttention, self).__init__(
         query_layer=layers_core.Dense(
-            num_units, name="query_layer", use_bias=False),
+            num_units, name="query_layer", use_bias=False, dtype=dtype),
         memory_layer=layers_core.Dense(
-            num_units, name="memory_layer", use_bias=False),
+            num_units, name="memory_layer", use_bias=False, dtype=dtype),
         memory=memory,
         probability_fn=wrapped_probability_fn,
         memory_sequence_length=memory_sequence_length,
@@ -830,11 +847,12 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism):
                memory,
                memory_sequence_length=None,
                scale=False,
-               score_mask_value=float("-inf"),
+               score_mask_value=None,
                sigmoid_noise=0.,
                sigmoid_noise_seed=None,
                score_bias_init=0.,
                mode="parallel",
+               dtype=None,
                name="LuongMonotonicAttention"):
     """Construct the Attention mechanism.

@@ -858,17 +876,21 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism):
       mode: How to compute the attention distribution. Must be one of
         'recursive', 'parallel', or 'hard'. See the docstring for
         `tf.contrib.seq2seq.monotonic_attention` for more information.
+      dtype: The data type for the query and memory layers of the attention
+        mechanism.
       name: Name to use when creating ops.
     """
     # Set up the monotonic probability fn with supplied parameters
+    if dtype is None:
+      dtype = dtypes.float32
     wrapped_probability_fn = functools.partial(
         _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode,
         seed=sigmoid_noise_seed)
     super(LuongMonotonicAttention, self).__init__(
         query_layer=layers_core.Dense(
-            num_units, name="query_layer", use_bias=False),
+            num_units, name="query_layer", use_bias=False, dtype=dtype),
         memory_layer=layers_core.Dense(
-            num_units, name="memory_layer", use_bias=False),
+            num_units, name="memory_layer", use_bias=False, dtype=dtype),
         memory=memory,
         probability_fn=wrapped_probability_fn,
         memory_sequence_length=memory_sequence_length,
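The two monotonic hunks above apply the identical pattern, so one hedged sketch covers both (again with made-up values and the TF 1.x contrib API):

import tensorflow as tf

encoder_outputs = tf.placeholder(tf.float16, [4, 7, 32])

monotonic_attention = tf.contrib.seq2seq.BahdanauMonotonicAttention(
    num_units=64,
    memory=encoder_outputs,
    sigmoid_noise=1.0,   # monotonic-specific arguments are unchanged
    mode="parallel",
    dtype=tf.float16)    # only the layer dtype handling is new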
@@ -1119,8 +1141,9 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
             % (len(attention_layer_sizes), len(attention_mechanisms)))
       self._attention_layers = tuple(
           layers_core.Dense(
-              attention_layer_size, name="attention_layer", use_bias=False)
-          for attention_layer_size in attention_layer_sizes)
+              attention_layer_size, name="attention_layer", use_bias=False,
+              dtype=attention_mechanisms[i].dtype)
+          for i, attention_layer_size in enumerate(attention_layer_sizes))
       self._attention_layer_size = sum(attention_layer_sizes)
     else:
       self._attention_layers = None
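Putting it together: with the hunk above, each attention projection layer in `AttentionWrapper` inherits the dtype of its attention mechanism, so a half-precision mechanism no longer collides with a float32 `attention_layer`. A sketch of the end-to-end wiring (hypothetical sizes, TF 1.x contrib API assumed):

import tensorflow as tf

batch_size, num_units = 4, 64
encoder_outputs = tf.placeholder(tf.float16, [batch_size, 7, 32])

attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units, encoder_outputs, dtype=tf.float16)

decoder_cell = tf.contrib.rnn.BasicLSTMCell(num_units)
attention_cell = tf.contrib.seq2seq.AttentionWrapper(
    decoder_cell,
    attention_mechanism,
    attention_layer_size=num_units)
# The internal attention_layer Dense is created with
# dtype=attention_mechanism.dtype, i.e. float16 here.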