[TF contrib seq2seq] Clean up normalized attention documentation and API
(removes the unnecessary r offset, since with softmax attention it makes no difference). Change: 151636533
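For intuition on why the r offset can be dropped, here is a small standalone check (illustrative only, not part of this commit): adding the same scalar r to every score is cancelled by the softmax normalization, so the attention weights are unchanged.

```python
# Illustrative check: a constant offset r added to all attention scores
# does not change the softmax attention weights.
import numpy as np

def softmax(x):
  e = np.exp(x - np.max(x))
  return e / e.sum()

scores = np.array([0.3, -1.2, 2.0, 0.7])
r = 2.0  # the post-normalization bias being removed
print(np.allclose(softmax(scores), softmax(scores + r)))  # True
```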
parent 1e97ff5f29
commit bc9ee91499
tensorflow/contrib/seq2seq/python
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 # pylint: enable=unused-import
 
+import sys
 import functools
 
 import numpy as np
@@ -138,8 +139,10 @@ class AttentionWrapperTest(test.TestCase):
 
 print("Copy/paste (%s)\nexpected_final_output = " % name,
 sess_results["final_outputs"])
+sys.stdout.flush()
 print("Copy/paste (%s)\nexpected_final_state = " % name,
 sess_results["final_state"])
+sys.stdout.flush()
 nest.map_structure(self.assertAllClose, expected_final_output,
 sess_results["final_outputs"])
 nest.map_structure(self.assertAllClose, expected_final_state,
@@ -151,7 +154,7 @@ class AttentionWrapperTest(test.TestCase):
 np.transpose(sess_results["final_outputs"].rnn_output,
 (1, 0, 2)))
 
-def testBahndahauNotNormalized(self):
+def testBahdanauNotNormalized(self):
 create_attention_mechanism = wrapper.BahdanauAttention
 
 expected_final_output = BasicDecoderOutput(
@@ -281,11 +284,11 @@ class AttentionWrapperTest(test.TestCase):
 expected_final_output,
 expected_final_state,
 attention_history=True,
-name="testBahndahauNotNormalized")
+name="testBahdanauNotNormalized")
 
-def testBahndahauNormalized(self):
+def testBahdanauNormalized(self):
 create_attention_mechanism = functools.partial(
-wrapper.BahdanauAttention, normalize=True, attention_r_initializer=2.0)
+wrapper.BahdanauAttention, normalize=True)
 
 expected_final_output = BasicDecoderOutput(
 rnn_output=array(
@@ -293,38 +296,38 @@ class AttentionWrapperTest(test.TestCase):
 6.64783875e-03, 2.94425711e-03, 5.26542449e-03, -2.64955591e-03,
 -7.95925129e-03, -5.02286293e-03
 ], [
-7.01954216e-03, 3.07301106e-03, 5.22849336e-03, -2.68844375e-03,
--7.11239828e-03, -4.72904276e-03
+7.01954123e-03, 3.07301106e-03, 5.22849336e-03, -2.68844375e-03,
+-7.11239874e-03, -4.72904276e-03
 ], [
-6.62360759e-03, 3.12234741e-03, 5.51807601e-03, -2.46222410e-03,
--7.40198744e-03, -4.85700974e-03
+6.62360899e-03, 3.12234787e-03, 5.51807694e-03, -2.46222341e-03,
+-7.40198931e-03, -4.85701021e-03
 ]], [[
-7.37590110e-03, -1.02620223e-03, 3.61374998e-03,
--5.74620720e-03, 5.05625363e-03, -7.45209027e-03
+7.37589924e-03, -1.02620223e-03, 3.61374952e-03,
+-5.74620720e-03, 5.05625410e-03, -7.45209027e-03
 ], [
-7.61946291e-03, -1.09287433e-03, 3.78817227e-03,
--5.78709692e-03, 4.56611067e-03, -7.96987675e-03
+7.61946291e-03, -1.09287468e-03, 3.78817180e-03,
+-5.78709645e-03, 4.56611114e-03, -7.96987582e-03
 ], [
-7.69207673e-03, -1.26582675e-03, 3.85218766e-03,
--5.81111759e-03, 4.63287160e-03, -7.86337163e-03
+7.69207766e-03, -1.26582675e-03, 3.85218812e-03,
+-5.81111759e-03, 4.63287206e-03, -7.86337163e-03
 ]], [[
--2.69413716e-03, 3.47182853e-04, -1.82145915e-03,
--1.39805046e-03, -8.05486552e-03, -1.08372122e-02
+-2.69413739e-03, 3.47183552e-04, -1.82145904e-03,
+-1.39805069e-03, -8.05486552e-03, -1.08372131e-02
 ], [
--2.70848861e-03, 3.03293811e-04, -1.67230773e-03,
--1.34555507e-03, -8.40565283e-03, -1.10935066e-02
+-2.70848931e-03, 3.03293345e-04, -1.67230750e-03,
+-1.34555507e-03, -8.40565283e-03, -1.10935047e-02
 ], [
--2.47822376e-03, 5.79409534e-04, -1.70188234e-03,
--1.42583461e-03, -7.81180616e-03, -1.10740727e-02
+-2.47822329e-03, 5.79408603e-04, -1.70188327e-03,
+-1.42583530e-03, -7.81180616e-03, -1.10740755e-02
 ]], [[
 1.48582947e-03, -3.88786104e-03, -9.39912978e-04,
 8.36255029e-04, -1.28223014e-03, 6.40908210e-03
 ], [
-1.18177070e-03, -4.47923224e-03, -1.05711189e-03,
-8.12121376e-04, -2.08477280e-03, 6.27523381e-03
+1.18177081e-03, -4.47923271e-03, -1.05711201e-03,
+8.12121783e-04, -2.08477327e-03, 6.27523474e-03
 ], [
-9.49664915e-04, -4.28958004e-03, -1.39053748e-03,
-6.29657647e-04, -2.14899099e-03, 6.52727857e-03
+9.49664740e-04, -4.28957958e-03, -1.39053771e-03,
+6.29657647e-04, -2.14899099e-03, 6.52727811e-03
 ]], [[
 -3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04,
 -1.56512906e-04, 9.63474595e-05
@@ -396,8 +399,8 @@ class AttentionWrapperTest(test.TestCase):
 0.00769208, -0.00126583, 0.00385219, -0.00581112, 0.00463287,
 -0.00786337
 ], [
--0.00247822, 0.00057941, -0.00170188, -0.00142583, -0.00781181,
--0.01107407
+-0.00247822, 0.00057941, -0.00170188, -0.00142584, -0.00781181,
+-0.01107408
 ], [
 0.00094966, -0.00428958, -0.00139054, 0.00062966, -0.00214899,
 0.00652728
@@ -413,7 +416,7 @@ class AttentionWrapperTest(test.TestCase):
 create_attention_mechanism,
 expected_final_output,
 expected_final_state,
-name="testBahndahauNormalized")
+name="testBahdanauNormalized")
 
 def testLuongNotNormalized(self):
 create_attention_mechanism = wrapper.LuongAttention
@@ -547,48 +550,48 @@ class AttentionWrapperTest(test.TestCase):
 attention_mechanism_depth=9,
 name="testLuongNotNormalized")
 
-def testLuongNormalized(self):
+def testLuongScaled(self):
 create_attention_mechanism = functools.partial(
-wrapper.LuongAttention, normalize=True, attention_r_initializer=2.0)
+wrapper.LuongAttention, scale=True)
 
 expected_final_output = BasicDecoderOutput(
 rnn_output=array(
 [[[
-1.77905359e-03, 1.90462871e-03, 2.03362387e-03, -3.82418092e-03,
--4.29544738e-03, -6.42893929e-03
+1.74749165e-03, 1.95862399e-03, 2.12293095e-03, -3.75889172e-03,
+-4.39571124e-03, -6.32379763e-03
 ], [
-2.22176919e-03, 2.00074376e-03, 1.96245289e-03, -3.89029179e-03,
--3.40791186e-03, -6.18315209e-03
+2.33045570e-03, 1.99094601e-03, 1.98377599e-03, -3.87950847e-03,
+-3.42792575e-03, -6.17497414e-03
 ], [
-1.73094252e-03, 2.02626130e-03, 2.17241934e-03, -3.71918478e-03,
--3.61261726e-03, -6.38988335e-03
+1.65032526e-03, 1.96972815e-03, 2.03462853e-03, -3.82007333e-03,
+-3.46369296e-03, -6.54224353e-03
 ]], [[
-4.82460577e-03, -1.97555218e-03, 3.31952656e-03,
--2.73223547e-03, 3.59055214e-03, -4.72807605e-03
+4.77780215e-03, -1.98677275e-03, 3.30950436e-03,
+-2.68179504e-03, 3.56271653e-03, -4.67860466e-03
 ], [
-5.10502933e-03, -2.03409232e-03, 3.50220758e-03,
--2.81245424e-03, 3.12219234e-03, -5.28475922e-03
+5.13039157e-03, -2.02797214e-03, 3.50760575e-03,
+-2.83981953e-03, 3.13726603e-03, -5.31156827e-03
 ], [
-5.16733387e-03, -2.19496130e-03, 3.55309760e-03,
--2.83683022e-03, 3.18047474e-03, -5.16855856e-03
+5.17205056e-03, -2.16446724e-03, 3.53219034e-03,
+-2.86490913e-03, 3.17879021e-03, -5.17592067e-03
 ]], [[
--1.34300231e-03, -8.74995370e-04, -2.08824431e-03,
--1.83634064e-03, -5.20561496e-03, -9.02957655e-03
+-1.38538703e-03, -6.40910701e-04, -2.02864106e-03,
+-1.79018872e-03, -5.18789608e-03, -8.95875692e-03
 ], [
--1.35285407e-03, -9.54896386e-04, -1.95000600e-03,
--1.79241551e-03, -5.56529313e-03, -9.30588506e-03
+-1.38620089e-03, -7.92010222e-04, -1.91070826e-03,
+-1.76206254e-03, -5.56525169e-03, -9.27332044e-03
 ], [
--1.11046492e-03, -7.09228800e-04, -1.98852271e-03,
--1.88147742e-03, -4.95750736e-03, -9.28789563e-03
+-1.11966045e-03, -6.07630936e-04, -1.96643686e-03,
+-1.86803937e-03, -4.93048411e-03, -9.25842486e-03
 ]], [[
-1.52318960e-03, -3.95966647e-03, -9.61032696e-04,
-8.49176315e-04, -1.29571836e-03, 6.54178672e-03
+1.50820788e-03, -3.93087184e-03, -9.52563598e-04,
+8.43994785e-04, -1.29030924e-03, 6.48857141e-03
 ], [
-1.20770617e-03, -4.52907849e-03, -1.07177906e-03,
-8.21074296e-04, -2.09413515e-03, 6.36733975e-03
+1.17029145e-03, -4.45716921e-03, -1.05062663e-03,
+8.08141369e-04, -2.08062865e-03, 6.23444980e-03
 ], [
-9.85645805e-04, -4.35873261e-03, -1.41088583e-03,
-6.42077066e-04, -2.16197851e-03, 6.65505463e-03
+9.67921398e-04, -4.32466762e-03, -1.40085898e-03,
+6.35969569e-04, -2.15558149e-03, 6.59212377e-03
 ]], [[
 -3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04,
 -1.56512906e-04, 9.63474595e-05
@@ -608,21 +611,21 @@ class AttentionWrapperTest(test.TestCase):
 cell_state=LSTMStateTuple(
 c=array(
 [[
--2.18956061e-02, -8.04580562e-03, -1.48322619e-03,
-1.61062591e-02, -1.37978457e-02, -7.57943699e-03,
--8.28466844e-03, -1.18735433e-02, 1.78835411e-02
+-2.18960866e-02, -8.04429129e-03, -1.48267671e-03,
+1.61071159e-02, -1.37981661e-02, -7.57933082e-03,
+-8.28570686e-03, -1.18733812e-02, 1.78834442e-02
 ], [
-1.74204111e-02, -1.41935302e-02, -3.88075318e-03,
-3.19713354e-02, -3.54694575e-02, -2.14688908e-02,
--6.21729670e-03, -1.69236294e-03, -1.94492973e-02
+1.74204130e-02, -1.41935758e-02, -3.88074201e-03,
+3.19713727e-02, -3.54694910e-02, -2.14688145e-02,
+-6.21731905e-03, -1.69229065e-03, -1.94492843e-02
 ], [
--1.14519196e-02, 8.77891947e-03, -1.62970666e-02,
--1.39963832e-02, 1.34855332e-02, -1.04490370e-02,
-6.16136147e-03, -9.41039063e-03, -6.57598302e-03
+-1.14494488e-02, 8.77974741e-03, -1.62960067e-02,
+-1.39961652e-02, 1.34879015e-02, -1.04502086e-02,
+6.15879148e-03, -9.40956455e-03, -6.57592434e-03
 ], [
--4.74749245e-02, -1.19127585e-02, -7.39064999e-05,
-4.10550945e-02, -1.36729144e-03, 2.11788230e-02,
--2.80466378e-02, -5.44511043e-02, -2.91905347e-02
+-4.74739634e-02, -1.19136050e-02, -7.36759976e-05,
+4.10547927e-02, -1.36767328e-03, 2.11772677e-02,
+-2.80479677e-02, -5.44514805e-02, -2.91903690e-02
 ], [
 2.25644894e-02, -1.40382675e-03, 1.92396250e-02,
 5.49034867e-03, -1.27930511e-02, -3.15603940e-03,
@@ -631,21 +634,21 @@ class AttentionWrapperTest(test.TestCase):
 dtype=float32),
 h=array(
 [[
--1.09837065e-02, -3.97554692e-03, -7.54752022e-04,
-7.91159924e-03, -7.02159712e-03, -3.80695192e-03,
--4.22011921e-03, -6.05454855e-03, 8.92061647e-03
+-1.09839402e-02, -3.97479767e-03, -7.54472159e-04,
+7.91201927e-03, -7.02175125e-03, -3.80689627e-03,
+-4.22065007e-03, -6.05447078e-03, 8.92056432e-03
 ], [
-8.68127029e-03, -7.16967974e-03, -1.88376172e-03,
-1.62681639e-02, -1.76830329e-02, -1.06617883e-02,
--3.07535171e-03, -8.45587580e-04, -9.99377016e-03
+8.68127123e-03, -7.16970162e-03, -1.88375649e-03,
+1.62681788e-02, -1.76830534e-02, -1.06617520e-02,
+-3.07536125e-03, -8.45551898e-04, -9.99375992e-03
 ], [
--5.71158715e-03, 4.50087292e-03, -8.07643402e-03,
--6.94846408e-03, 6.75802259e-03, -5.12090744e-03,
-3.06212017e-03, -4.61751036e-03, -3.23935202e-03
+-5.71034756e-03, 4.50129062e-03, -8.07590690e-03,
+-6.94835978e-03, 6.75921654e-03, -5.12148207e-03,
+3.06083867e-03, -4.61710012e-03, -3.23932176e-03
 ], [
--2.37229262e-02, -5.88545902e-03, -3.71685164e-05,
-2.01788824e-02, -6.75938500e-04, 1.06682768e-02,
--1.42627759e-02, -2.69629639e-02, -1.45033952e-02
+-2.37224493e-02, -5.88587578e-03, -3.70525813e-05,
+2.01787278e-02, -6.76127791e-04, 1.06675029e-02,
+-1.42634306e-02, -2.69631632e-02, -1.45033058e-02
 ], [
 1.12585640e-02, -6.92534202e-04, 9.88917705e-03,
 2.75237625e-03, -6.56115822e-03, -1.57997780e-03,
@@ -654,17 +657,17 @@ class AttentionWrapperTest(test.TestCase):
 dtype=float32)),
 attention=array(
 [[
-0.00173094, 0.00202626, 0.00217242, -0.00371918, -0.00361262,
--0.00638988
+0.00165033, 0.00196973, 0.00203463, -0.00382007, -0.00346369,
+-0.00654224
 ], [
-0.00516733, -0.00219496, 0.0035531, -0.00283683, 0.00318047,
--0.00516856
+0.00517205, -0.00216447, 0.00353219, -0.00286491, 0.00317879,
+-0.00517592
 ], [
--0.00111046, -0.00070923, -0.00198852, -0.00188148, -0.00495751,
--0.0092879
+-0.00111966, -0.00060763, -0.00196644, -0.00186804, -0.00493048,
+-0.00925842
 ], [
-0.00098565, -0.00435873, -0.00141089, 0.00064208, -0.00216198,
-0.00665505
+0.00096792, -0.00432467, -0.00140086, 0.00063597, -0.00215558,
+0.00659212
 ], [
 0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734,
 -0.00026981
@@ -678,7 +681,7 @@ class AttentionWrapperTest(test.TestCase):
 expected_final_output,
 expected_final_state,
 attention_mechanism_depth=9,
-name="testLuongNormalized")
+name="testLuongScaled")
 
 
 if __name__ == "__main__":
@@ -176,19 +176,18 @@ class LuongAttention(_BaseAttentionMechanism):
 "Effective Approaches to Attention-based Neural Machine Translation."
 EMNLP 2015. https://arxiv.org/abs/1508.04025
 
-The second is the normalized form. This form is inspired by the
-normalization proposed for Bahdanau attention in
-
-Colin Raffel, Thang Luong, Peter J. Liu, Ron J. Weiss, and Douglas Eck.
-"Online and Linear-Time Attention by Enforcing Monotonic Alignments."
-(Eq. 15).
+The second is the scaled form inspired partly by the normalized form of
+Bahdanau attention.
 
 To enable the second form, construct the object with parameter
-`normalize=True`.
+`scale=True`.
 """
 
-def __init__(self, num_units, memory, memory_sequence_length=None,
-normalize=False, attention_r_initializer=None,
+def __init__(self,
+num_units,
+memory,
+memory_sequence_length=None,
+scale=False,
 name="LuongAttention"):
 """Construct the AttentionMechanism mechanism.
 
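A minimal usage sketch of the updated LuongAttention constructor (the import path and toy shapes are assumptions for illustration; the API itself comes from the hunks above and below). With `scale=True`, the mechanism learns a single scalar `attention_g` and computes `score = g * matmul(query, keys^T)`; the `attention_r_initializer` argument no longer exists.

```python
# Sketch only, assuming the contrib module path used by the test file.
import tensorflow as tf
from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as wrapper

memory = tf.random_normal([2, 5, 9])           # [batch, max_time, depth] (toy values)
memory_sequence_length = tf.constant([5, 3])   # valid lengths per batch entry

luong = wrapper.LuongAttention(
    num_units=9,
    memory=memory,
    memory_sequence_length=memory_sequence_length,
    scale=True)  # was: normalize=True, attention_r_initializer=...

query = tf.random_normal([2, 9])               # [batch, num_units]
scores = luong(query)                          # [batch, max_time]; softmax gives the weights
```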
@@ -199,9 +198,7 @@ class LuongAttention(_BaseAttentionMechanism):
 memory_sequence_length (optional): Sequence lengths for the batch entries
 in memory. If provided, the memory tensor rows are masked with zeros
 for values past the respective sequence lengths.
-normalize: Python boolean. Whether to normalize the energy term.
-attention_r_initializer: Initial value of the post-normalization bias
-when normalizing. Default is `0`.
+scale: Python boolean. Whether to scale the energy term.
 name: Name to use when creating ops.
 """
 # For LuongAttention, we only transform the memory layer; thus
@@ -214,18 +211,8 @@ class LuongAttention(_BaseAttentionMechanism):
 memory_sequence_length=memory_sequence_length,
 name=name)
 self._num_units = num_units
-self._normalize = normalize
+self._scale = scale
 self._name = name
-if normalize and attention_r_initializer is None:
-attention_r_initializer = 0
-if normalize:
-with ops.name_scope(
-name, "LuongAttentionInit",
-[memory, attention_r_initializer]):
-attention_r_initializer = ops.convert_to_tensor(
-attention_r_initializer, dtype=self.values.dtype,
-name="attention_r_initializer")
-self._attention_r_initializer = attention_r_initializer
 
 def __call__(self, query):
 """Score the query based on the keys and values.
@@ -268,16 +255,11 @@ class LuongAttention(_BaseAttentionMechanism):
 score = math_ops.matmul(query, self.keys, transpose_b=True)
 score = array_ops.squeeze(score, [1])
 
-if self._normalize:
-# Scalar used in weight normalization
+if self._scale:
+# Scalar used in weight scaling
 g = variable_scope.get_variable(
-"attention_g", dtype=dtype,
-initializer=math.sqrt((1. / self._num_units)))
-# Scalar bias added to attention scores
-r = variable_scope.get_variable(
-"attention_r", dtype=dtype,
-initializer=self._attention_r_initializer)
-score = g * score + r
+"attention_g", dtype=dtype, initializer=1.)
+score = g * score
 
 return score
 
@@ -292,18 +274,23 @@ class BahdanauAttention(_BaseAttentionMechanism):
 "Neural Machine Translation by Jointly Learning to Align and Translate."
 ICLR 2015. https://arxiv.org/abs/1409.0473
 
-The second is the normalized form, Raffel attention, as described in:
+The second is the normalized form. This form is inspired by the
+weight normalization article:
 
-Colin Raffel, Thang Luong, Peter J. Liu, Ron J. Weiss, and Douglas Eck.
-"Online and Linear-Time Attention by Enforcing Monotonic Alignments."
-(Eq. 15).
+Tim Salimans, Diederik P. Kingma.
+"Weight Normalization: A Simple Reparameterization to Accelerate
+Training of Deep Neural Networks."
+https://arxiv.org/abs/1602.07868
 
 To enable the second form, construct the object with parameter
 `normalize=True`.
 """
 
-def __init__(self, num_units, memory, memory_sequence_length=None,
-normalize=False, attention_r_initializer=None,
+def __init__(self,
+num_units,
+memory,
+memory_sequence_length=None,
+normalize=False,
 name="BahdanauAttention"):
 """Construct the Attention mechanism.
 
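A companion sketch for the updated BahdanauAttention constructor (again, import path and toy shapes are assumptions; the API follows the surrounding hunks). With `normalize=True`, the energy uses a weight-normalized `v`, i.e. `score = sum(g * v / ||v|| * tanh(keys + processed_query + b))`, and the former post-normalization bias `r` is removed.

```python
# Sketch only, assuming the contrib module path used by the test file.
import tensorflow as tf
from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as wrapper

memory = tf.random_normal([2, 5, 6])           # [batch, max_time, depth] (toy values)

bahdanau = wrapper.BahdanauAttention(
    num_units=6,
    memory=memory,
    memory_sequence_length=tf.constant([5, 4]),
    normalize=True)  # attention_r_initializer is gone

scores = bahdanau(tf.random_normal([2, 6]))    # [batch, max_time]
```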
@@ -315,8 +302,6 @@ class BahdanauAttention(_BaseAttentionMechanism):
 in memory. If provided, the memory tensor rows are masked with zeros
 for values past the respective sequence lengths.
 normalize: Python boolean. Whether to normalize the energy term.
-attention_r_initializer: Initial value of the post-normalization bias
-when normalizing. Default is `0`.
 name: Name to use when creating ops.
 """
 super(BahdanauAttention, self).__init__(
@@ -330,15 +315,6 @@ class BahdanauAttention(_BaseAttentionMechanism):
 self._num_units = num_units
 self._normalize = normalize
 self._name = name
-if normalize and attention_r_initializer is None:
-attention_r_initializer = 0
-if normalize:
-with ops.name_scope(name, "BahdanauAttentionInit",
-[memory, attention_r_initializer]):
-attention_r_initializer = ops.convert_to_tensor(
-attention_r_initializer, dtype=self.values.dtype,
-name="attention_r_initializer")
-self._attention_r_initializer = attention_r_initializer
 
 def __call__(self, query):
 """Score the query based on the keys and values.
@@ -350,7 +326,7 @@ class BahdanauAttention(_BaseAttentionMechanism):
 score: Tensor of dtype matching `self.values` and shape
 `[batch_size, self.num_units]`.
 """
-with variable_scope.variable_scope(None, "bahndahau_attention", [query]):
+with variable_scope.variable_scope(None, "bahdanau_attention", [query]):
 processed_query = self.query_layer(query) if self.query_layer else query
 dtype = processed_query.dtype
 # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting.
@@ -366,15 +342,11 @@ class BahdanauAttention(_BaseAttentionMechanism):
 b = variable_scope.get_variable(
 "attention_b", [self._num_units], dtype=dtype,
 initializer=init_ops.zeros_initializer())
-# Scalar bias added to attention scores
-r = variable_scope.get_variable(
-"attention_r", dtype=dtype,
-initializer=self._attention_r_initializer)
 # normed_v = g * v / ||v||
 normed_v = g * v * math_ops.rsqrt(
 math_ops.reduce_sum(math_ops.square(v)))
 score = math_ops.reduce_sum(
-normed_v * math_ops.tanh(self.keys + processed_query + b), [2]) + r
+normed_v * math_ops.tanh(self.keys + processed_query + b), [2])
 else:
 score = math_ops.reduce_sum(
 v * math_ops.tanh(self.keys + processed_query), [2])