[TF contrib seq2seq] Cleanup normalized attention documentation and API

(removes the unnecessary r offset, since with softmax attention it makes no
difference)
Change: 151636533
Eugene Brevdo 2017-03-29 14:43:04 -08:00 committed by TensorFlower Gardener
parent 1e97ff5f29
commit bc9ee91499
2 changed files with 116 additions and 141 deletions
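
The rationale for dropping r: attention weights are obtained from a softmax over the scores, and softmax is invariant to adding the same constant to every score, so a scalar post-normalization bias cannot change the resulting alignments. A minimal numpy sketch of that invariance (illustrative only, not part of the commit):

import numpy as np

def softmax(x):
    # Shift by the max for numerical stability; the result is unchanged.
    e = np.exp(x - np.max(x))
    return e / e.sum()

scores = np.array([0.3, -1.2, 2.5, 0.0])
r = 2.0  # stands in for the removed post-normalization bias

# Adding the same offset r to every score leaves the softmax untouched.
print(np.allclose(softmax(scores), softmax(scores + r)))  # True
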
tensorflow/contrib/seq2seq/python


@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
# pylint: enable=unused-import
import sys
import functools
import numpy as np
@@ -138,8 +139,10 @@ class AttentionWrapperTest(test.TestCase):
print("Copy/paste (%s)\nexpected_final_output = " % name,
sess_results["final_outputs"])
sys.stdout.flush()
print("Copy/paste (%s)\nexpected_final_state = " % name,
sess_results["final_state"])
sys.stdout.flush()
nest.map_structure(self.assertAllClose, expected_final_output,
sess_results["final_outputs"])
nest.map_structure(self.assertAllClose, expected_final_state,
@@ -151,7 +154,7 @@ class AttentionWrapperTest(test.TestCase):
np.transpose(sess_results["final_outputs"].rnn_output,
(1, 0, 2)))
def testBahndahauNotNormalized(self):
def testBahdanauNotNormalized(self):
create_attention_mechanism = wrapper.BahdanauAttention
expected_final_output = BasicDecoderOutput(
@@ -281,11 +284,11 @@ class AttentionWrapperTest(test.TestCase):
expected_final_output,
expected_final_state,
attention_history=True,
name="testBahndahauNotNormalized")
name="testBahdanauNotNormalized")
def testBahndahauNormalized(self):
def testBahdanauNormalized(self):
create_attention_mechanism = functools.partial(
wrapper.BahdanauAttention, normalize=True, attention_r_initializer=2.0)
wrapper.BahdanauAttention, normalize=True)
expected_final_output = BasicDecoderOutput(
rnn_output=array(
@@ -293,38 +296,38 @@ class AttentionWrapperTest(test.TestCase):
6.64783875e-03, 2.94425711e-03, 5.26542449e-03, -2.64955591e-03,
-7.95925129e-03, -5.02286293e-03
], [
7.01954216e-03, 3.07301106e-03, 5.22849336e-03, -2.68844375e-03,
-7.11239828e-03, -4.72904276e-03
7.01954123e-03, 3.07301106e-03, 5.22849336e-03, -2.68844375e-03,
-7.11239874e-03, -4.72904276e-03
], [
6.62360759e-03, 3.12234741e-03, 5.51807601e-03, -2.46222410e-03,
-7.40198744e-03, -4.85700974e-03
6.62360899e-03, 3.12234787e-03, 5.51807694e-03, -2.46222341e-03,
-7.40198931e-03, -4.85701021e-03
]], [[
7.37590110e-03, -1.02620223e-03, 3.61374998e-03,
-5.74620720e-03, 5.05625363e-03, -7.45209027e-03
7.37589924e-03, -1.02620223e-03, 3.61374952e-03,
-5.74620720e-03, 5.05625410e-03, -7.45209027e-03
], [
7.61946291e-03, -1.09287433e-03, 3.78817227e-03,
-5.78709692e-03, 4.56611067e-03, -7.96987675e-03
7.61946291e-03, -1.09287468e-03, 3.78817180e-03,
-5.78709645e-03, 4.56611114e-03, -7.96987582e-03
], [
7.69207673e-03, -1.26582675e-03, 3.85218766e-03,
-5.81111759e-03, 4.63287160e-03, -7.86337163e-03
7.69207766e-03, -1.26582675e-03, 3.85218812e-03,
-5.81111759e-03, 4.63287206e-03, -7.86337163e-03
]], [[
-2.69413716e-03, 3.47182853e-04, -1.82145915e-03,
-1.39805046e-03, -8.05486552e-03, -1.08372122e-02
-2.69413739e-03, 3.47183552e-04, -1.82145904e-03,
-1.39805069e-03, -8.05486552e-03, -1.08372131e-02
], [
-2.70848861e-03, 3.03293811e-04, -1.67230773e-03,
-1.34555507e-03, -8.40565283e-03, -1.10935066e-02
-2.70848931e-03, 3.03293345e-04, -1.67230750e-03,
-1.34555507e-03, -8.40565283e-03, -1.10935047e-02
], [
-2.47822376e-03, 5.79409534e-04, -1.70188234e-03,
-1.42583461e-03, -7.81180616e-03, -1.10740727e-02
-2.47822329e-03, 5.79408603e-04, -1.70188327e-03,
-1.42583530e-03, -7.81180616e-03, -1.10740755e-02
]], [[
1.48582947e-03, -3.88786104e-03, -9.39912978e-04,
8.36255029e-04, -1.28223014e-03, 6.40908210e-03
], [
1.18177070e-03, -4.47923224e-03, -1.05711189e-03,
8.12121376e-04, -2.08477280e-03, 6.27523381e-03
1.18177081e-03, -4.47923271e-03, -1.05711201e-03,
8.12121783e-04, -2.08477327e-03, 6.27523474e-03
], [
9.49664915e-04, -4.28958004e-03, -1.39053748e-03,
6.29657647e-04, -2.14899099e-03, 6.52727857e-03
9.49664740e-04, -4.28957958e-03, -1.39053771e-03,
6.29657647e-04, -2.14899099e-03, 6.52727811e-03
]], [[
-3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04,
-1.56512906e-04, 9.63474595e-05
@@ -396,8 +399,8 @@ class AttentionWrapperTest(test.TestCase):
0.00769208, -0.00126583, 0.00385219, -0.00581112, 0.00463287,
-0.00786337
], [
-0.00247822, 0.00057941, -0.00170188, -0.00142583, -0.00781181,
-0.01107407
-0.00247822, 0.00057941, -0.00170188, -0.00142584, -0.00781181,
-0.01107408
], [
0.00094966, -0.00428958, -0.00139054, 0.00062966, -0.00214899,
0.00652728
@@ -413,7 +416,7 @@ class AttentionWrapperTest(test.TestCase):
create_attention_mechanism,
expected_final_output,
expected_final_state,
name="testBahndahauNormalized")
name="testBahdanauNormalized")
def testLuongNotNormalized(self):
create_attention_mechanism = wrapper.LuongAttention
@@ -547,48 +550,48 @@ class AttentionWrapperTest(test.TestCase):
attention_mechanism_depth=9,
name="testLuongNotNormalized")
def testLuongNormalized(self):
def testLuongScaled(self):
create_attention_mechanism = functools.partial(
wrapper.LuongAttention, normalize=True, attention_r_initializer=2.0)
wrapper.LuongAttention, scale=True)
expected_final_output = BasicDecoderOutput(
rnn_output=array(
[[[
1.77905359e-03, 1.90462871e-03, 2.03362387e-03, -3.82418092e-03,
-4.29544738e-03, -6.42893929e-03
1.74749165e-03, 1.95862399e-03, 2.12293095e-03, -3.75889172e-03,
-4.39571124e-03, -6.32379763e-03
], [
2.22176919e-03, 2.00074376e-03, 1.96245289e-03, -3.89029179e-03,
-3.40791186e-03, -6.18315209e-03
2.33045570e-03, 1.99094601e-03, 1.98377599e-03, -3.87950847e-03,
-3.42792575e-03, -6.17497414e-03
], [
1.73094252e-03, 2.02626130e-03, 2.17241934e-03, -3.71918478e-03,
-3.61261726e-03, -6.38988335e-03
1.65032526e-03, 1.96972815e-03, 2.03462853e-03, -3.82007333e-03,
-3.46369296e-03, -6.54224353e-03
]], [[
4.82460577e-03, -1.97555218e-03, 3.31952656e-03,
-2.73223547e-03, 3.59055214e-03, -4.72807605e-03
4.77780215e-03, -1.98677275e-03, 3.30950436e-03,
-2.68179504e-03, 3.56271653e-03, -4.67860466e-03
], [
5.10502933e-03, -2.03409232e-03, 3.50220758e-03,
-2.81245424e-03, 3.12219234e-03, -5.28475922e-03
5.13039157e-03, -2.02797214e-03, 3.50760575e-03,
-2.83981953e-03, 3.13726603e-03, -5.31156827e-03
], [
5.16733387e-03, -2.19496130e-03, 3.55309760e-03,
-2.83683022e-03, 3.18047474e-03, -5.16855856e-03
5.17205056e-03, -2.16446724e-03, 3.53219034e-03,
-2.86490913e-03, 3.17879021e-03, -5.17592067e-03
]], [[
-1.34300231e-03, -8.74995370e-04, -2.08824431e-03,
-1.83634064e-03, -5.20561496e-03, -9.02957655e-03
-1.38538703e-03, -6.40910701e-04, -2.02864106e-03,
-1.79018872e-03, -5.18789608e-03, -8.95875692e-03
], [
-1.35285407e-03, -9.54896386e-04, -1.95000600e-03,
-1.79241551e-03, -5.56529313e-03, -9.30588506e-03
-1.38620089e-03, -7.92010222e-04, -1.91070826e-03,
-1.76206254e-03, -5.56525169e-03, -9.27332044e-03
], [
-1.11046492e-03, -7.09228800e-04, -1.98852271e-03,
-1.88147742e-03, -4.95750736e-03, -9.28789563e-03
-1.11966045e-03, -6.07630936e-04, -1.96643686e-03,
-1.86803937e-03, -4.93048411e-03, -9.25842486e-03
]], [[
1.52318960e-03, -3.95966647e-03, -9.61032696e-04,
8.49176315e-04, -1.29571836e-03, 6.54178672e-03
1.50820788e-03, -3.93087184e-03, -9.52563598e-04,
8.43994785e-04, -1.29030924e-03, 6.48857141e-03
], [
1.20770617e-03, -4.52907849e-03, -1.07177906e-03,
8.21074296e-04, -2.09413515e-03, 6.36733975e-03
1.17029145e-03, -4.45716921e-03, -1.05062663e-03,
8.08141369e-04, -2.08062865e-03, 6.23444980e-03
], [
9.85645805e-04, -4.35873261e-03, -1.41088583e-03,
6.42077066e-04, -2.16197851e-03, 6.65505463e-03
9.67921398e-04, -4.32466762e-03, -1.40085898e-03,
6.35969569e-04, -2.15558149e-03, 6.59212377e-03
]], [[
-3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04,
-1.56512906e-04, 9.63474595e-05
@@ -608,21 +611,21 @@ class AttentionWrapperTest(test.TestCase):
cell_state=LSTMStateTuple(
c=array(
[[
-2.18956061e-02, -8.04580562e-03, -1.48322619e-03,
1.61062591e-02, -1.37978457e-02, -7.57943699e-03,
-8.28466844e-03, -1.18735433e-02, 1.78835411e-02
-2.18960866e-02, -8.04429129e-03, -1.48267671e-03,
1.61071159e-02, -1.37981661e-02, -7.57933082e-03,
-8.28570686e-03, -1.18733812e-02, 1.78834442e-02
], [
1.74204111e-02, -1.41935302e-02, -3.88075318e-03,
3.19713354e-02, -3.54694575e-02, -2.14688908e-02,
-6.21729670e-03, -1.69236294e-03, -1.94492973e-02
1.74204130e-02, -1.41935758e-02, -3.88074201e-03,
3.19713727e-02, -3.54694910e-02, -2.14688145e-02,
-6.21731905e-03, -1.69229065e-03, -1.94492843e-02
], [
-1.14519196e-02, 8.77891947e-03, -1.62970666e-02,
-1.39963832e-02, 1.34855332e-02, -1.04490370e-02,
6.16136147e-03, -9.41039063e-03, -6.57598302e-03
-1.14494488e-02, 8.77974741e-03, -1.62960067e-02,
-1.39961652e-02, 1.34879015e-02, -1.04502086e-02,
6.15879148e-03, -9.40956455e-03, -6.57592434e-03
], [
-4.74749245e-02, -1.19127585e-02, -7.39064999e-05,
4.10550945e-02, -1.36729144e-03, 2.11788230e-02,
-2.80466378e-02, -5.44511043e-02, -2.91905347e-02
-4.74739634e-02, -1.19136050e-02, -7.36759976e-05,
4.10547927e-02, -1.36767328e-03, 2.11772677e-02,
-2.80479677e-02, -5.44514805e-02, -2.91903690e-02
], [
2.25644894e-02, -1.40382675e-03, 1.92396250e-02,
5.49034867e-03, -1.27930511e-02, -3.15603940e-03,
@@ -631,21 +634,21 @@ class AttentionWrapperTest(test.TestCase):
dtype=float32),
h=array(
[[
-1.09837065e-02, -3.97554692e-03, -7.54752022e-04,
7.91159924e-03, -7.02159712e-03, -3.80695192e-03,
-4.22011921e-03, -6.05454855e-03, 8.92061647e-03
-1.09839402e-02, -3.97479767e-03, -7.54472159e-04,
7.91201927e-03, -7.02175125e-03, -3.80689627e-03,
-4.22065007e-03, -6.05447078e-03, 8.92056432e-03
], [
8.68127029e-03, -7.16967974e-03, -1.88376172e-03,
1.62681639e-02, -1.76830329e-02, -1.06617883e-02,
-3.07535171e-03, -8.45587580e-04, -9.99377016e-03
8.68127123e-03, -7.16970162e-03, -1.88375649e-03,
1.62681788e-02, -1.76830534e-02, -1.06617520e-02,
-3.07536125e-03, -8.45551898e-04, -9.99375992e-03
], [
-5.71158715e-03, 4.50087292e-03, -8.07643402e-03,
-6.94846408e-03, 6.75802259e-03, -5.12090744e-03,
3.06212017e-03, -4.61751036e-03, -3.23935202e-03
-5.71034756e-03, 4.50129062e-03, -8.07590690e-03,
-6.94835978e-03, 6.75921654e-03, -5.12148207e-03,
3.06083867e-03, -4.61710012e-03, -3.23932176e-03
], [
-2.37229262e-02, -5.88545902e-03, -3.71685164e-05,
2.01788824e-02, -6.75938500e-04, 1.06682768e-02,
-1.42627759e-02, -2.69629639e-02, -1.45033952e-02
-2.37224493e-02, -5.88587578e-03, -3.70525813e-05,
2.01787278e-02, -6.76127791e-04, 1.06675029e-02,
-1.42634306e-02, -2.69631632e-02, -1.45033058e-02
], [
1.12585640e-02, -6.92534202e-04, 9.88917705e-03,
2.75237625e-03, -6.56115822e-03, -1.57997780e-03,
@@ -654,17 +657,17 @@ class AttentionWrapperTest(test.TestCase):
dtype=float32)),
attention=array(
[[
0.00173094, 0.00202626, 0.00217242, -0.00371918, -0.00361262,
-0.00638988
0.00165033, 0.00196973, 0.00203463, -0.00382007, -0.00346369,
-0.00654224
], [
0.00516733, -0.00219496, 0.0035531, -0.00283683, 0.00318047,
-0.00516856
0.00517205, -0.00216447, 0.00353219, -0.00286491, 0.00317879,
-0.00517592
], [
-0.00111046, -0.00070923, -0.00198852, -0.00188148, -0.00495751,
-0.0092879
-0.00111966, -0.00060763, -0.00196644, -0.00186804, -0.00493048,
-0.00925842
], [
0.00098565, -0.00435873, -0.00141089, 0.00064208, -0.00216198,
0.00665505
0.00096792, -0.00432467, -0.00140086, 0.00063597, -0.00215558,
0.00659212
], [
0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734,
-0.00026981
@@ -678,7 +681,7 @@ class AttentionWrapperTest(test.TestCase):
expected_final_output,
expected_final_state,
attention_mechanism_depth=9,
name="testLuongNormalized")
name="testLuongScaled")
if __name__ == "__main__":


@@ -176,19 +176,18 @@ class LuongAttention(_BaseAttentionMechanism):
"Effective Approaches to Attention-based Neural Machine Translation."
EMNLP 2015. https://arxiv.org/abs/1508.04025
The second is the normalized form. This form is inspired by the
normalization proposed for Bahdanau attention in
Colin Raffel, Thang Luong, Peter J. Liu, Ron J. Weiss, and Douglas Eck.
"Online and Linear-Time Attention by Enforcing Monotonic Alignments."
(Eq. 15).
The second is the scaled form inspired partly by the normalized form of
Bahdanau attention.
To enable the second form, construct the object with parameter
`normalize=True`.
`scale=True`.
"""
def __init__(self, num_units, memory, memory_sequence_length=None,
normalize=False, attention_r_initializer=None,
def __init__(self,
num_units,
memory,
memory_sequence_length=None,
scale=False,
name="LuongAttention"):
"""Construct the AttentionMechanism mechanism.
@@ -199,9 +198,7 @@ class LuongAttention(_BaseAttentionMechanism):
memory_sequence_length (optional): Sequence lengths for the batch entries
in memory. If provided, the memory tensor rows are masked with zeros
for values past the respective sequence lengths.
normalize: Python boolean. Whether to normalize the energy term.
attention_r_initializer: Initial value of the post-normalization bias
when normalizing. Default is `0`.
scale: Python boolean. Whether to scale the energy term.
name: Name to use when creating ops.
"""
# For LuongAttention, we only transform the memory layer; thus
@@ -214,18 +211,8 @@ class LuongAttention(_BaseAttentionMechanism):
memory_sequence_length=memory_sequence_length,
name=name)
self._num_units = num_units
self._normalize = normalize
self._scale = scale
self._name = name
if normalize and attention_r_initializer is None:
attention_r_initializer = 0
if normalize:
with ops.name_scope(
name, "LuongAttentionInit",
[memory, attention_r_initializer]):
attention_r_initializer = ops.convert_to_tensor(
attention_r_initializer, dtype=self.values.dtype,
name="attention_r_initializer")
self._attention_r_initializer = attention_r_initializer
def __call__(self, query):
"""Score the query based on the keys and values.
@@ -268,16 +255,11 @@ class LuongAttention(_BaseAttentionMechanism):
score = math_ops.matmul(query, self.keys, transpose_b=True)
score = array_ops.squeeze(score, [1])
if self._normalize:
# Scalar used in weight normalization
if self._scale:
# Scalar used in weight scaling
g = variable_scope.get_variable(
"attention_g", dtype=dtype,
initializer=math.sqrt((1. / self._num_units)))
# Scalar bias added to attention scores
r = variable_scope.get_variable(
"attention_r", dtype=dtype,
initializer=self._attention_r_initializer)
score = g * score + r
"attention_g", dtype=dtype, initializer=1.)
score = g * score
return score
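
In plain numpy terms, the scaled scoring above reduces to the following sketch; the shapes are made up, and g stands in for the learned attention_g scalar, which this diff initializes to 1.0:

import numpy as np

batch_size, max_time, num_units = 2, 5, 4
query = np.random.randn(batch_size, num_units)            # decoder query
keys = np.random.randn(batch_size, max_time, num_units)   # transformed memory

# Plain Luong (dot-product) scores: [batch_size, max_time].
score = np.einsum('bu,btu->bt', query, keys)

# With scale=True a single learned scalar rescales the scores; starting
# from g = 1.0 the scaled form initially matches the unscaled one.
g = 1.0
score = g * score
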
@@ -292,18 +274,23 @@ class BahdanauAttention(_BaseAttentionMechanism):
"Neural Machine Translation by Jointly Learning to Align and Translate."
ICLR 2015. https://arxiv.org/abs/1409.0473
The second is the normalized form, Raffel attention, as described in:
The second is the normalized form. This form is inspired by the
weight normalization article:
Colin Raffel, Thang Luong, Peter J. Liu, Ron J. Weiss, and Douglas Eck.
"Online and Linear-Time Attention by Enforcing Monotonic Alignments."
(Eq. 15).
Tim Salimans, Diederik P. Kingma.
"Weight Normalization: A Simple Reparameterization to Accelerate
Training of Deep Neural Networks."
https://arxiv.org/abs/1602.07868
To enable the second form, construct the object with parameter
`normalize=True`.
"""
def __init__(self, num_units, memory, memory_sequence_length=None,
normalize=False, attention_r_initializer=None,
def __init__(self,
num_units,
memory,
memory_sequence_length=None,
normalize=False,
name="BahdanauAttention"):
"""Construct the Attention mechanism.
@@ -315,8 +302,6 @@ class BahdanauAttention(_BaseAttentionMechanism):
in memory. If provided, the memory tensor rows are masked with zeros
for values past the respective sequence lengths.
normalize: Python boolean. Whether to normalize the energy term.
attention_r_initializer: Initial value of the post-normalization bias
when normalizing. Default is `0`.
name: Name to use when creating ops.
"""
super(BahdanauAttention, self).__init__(
@@ -330,15 +315,6 @@ class BahdanauAttention(_BaseAttentionMechanism):
self._num_units = num_units
self._normalize = normalize
self._name = name
if normalize and attention_r_initializer is None:
attention_r_initializer = 0
if normalize:
with ops.name_scope(name, "BahdanauAttentionInit",
[memory, attention_r_initializer]):
attention_r_initializer = ops.convert_to_tensor(
attention_r_initializer, dtype=self.values.dtype,
name="attention_r_initializer")
self._attention_r_initializer = attention_r_initializer
def __call__(self, query):
"""Score the query based on the keys and values.
@@ -350,7 +326,7 @@ class BahdanauAttention(_BaseAttentionMechanism):
score: Tensor of dtype matching `self.values` and shape
`[batch_size, self.num_units]`.
"""
with variable_scope.variable_scope(None, "bahndahau_attention", [query]):
with variable_scope.variable_scope(None, "bahdanau_attention", [query]):
processed_query = self.query_layer(query) if self.query_layer else query
dtype = processed_query.dtype
# Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting.
@@ -366,15 +342,11 @@ class BahdanauAttention(_BaseAttentionMechanism):
b = variable_scope.get_variable(
"attention_b", [self._num_units], dtype=dtype,
initializer=init_ops.zeros_initializer())
# Scalar bias added to attention scores
r = variable_scope.get_variable(
"attention_r", dtype=dtype,
initializer=self._attention_r_initializer)
# normed_v = g * v / ||v||
normed_v = g * v * math_ops.rsqrt(
math_ops.reduce_sum(math_ops.square(v)))
score = math_ops.reduce_sum(
normed_v * math_ops.tanh(self.keys + processed_query + b), [2]) + r
normed_v * math_ops.tanh(self.keys + processed_query + b), [2])
else:
score = math_ops.reduce_sum(
v * math_ops.tanh(self.keys + processed_query), [2])
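
Putting the normalized branch together: the score is now the weight-normalized energy alone, with no trailing scalar r, and the softmax taken over it downstream is unaffected by that removal. A numpy sketch under assumed shapes (the attention_g initializer is defined above this hunk, so its value here is an assumption):

import numpy as np

batch_size, max_time, num_units = 2, 5, 4
keys = np.random.randn(batch_size, max_time, num_units)        # processed memory
processed_query = np.random.randn(batch_size, 1, num_units)    # expanded for broadcasting
v = np.random.randn(num_units)
b = np.zeros(num_units)
g = np.sqrt(1.0 / num_units)   # assumed initial value of attention_g

# normed_v = g * v / ||v||, as in the weight-normalized branch above.
normed_v = g * v / np.sqrt(np.sum(np.square(v)))

# [batch_size, max_time]; the scalar r offset is gone, and the softmax
# over max_time is unchanged by its removal.
score = np.sum(normed_v * np.tanh(keys + processed_query + b), axis=2)
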