[tf contrib seq2seq] Updates to AttentionMechanism API

* Move the probability_fn into the AttentionMechanism and out of AttentionWrapper.
* Propagate the previous alignment through the attention state and pass it to the AttentionMechanism.
* Simplify the unit tests.

Necessary for Colin's monotonic attention.

PiperOrigin-RevId: 155920141
Author: Eugene Brevdo, 2017-05-12 15:39:26 -07:00 (committed by TensorFlower Gardener)
Parent: be9af9685c
Commit: 49a259db15
2 changed files with 191 additions and 662 deletions
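As a minimal usage sketch of the updated surface (not part of this commit; the encoder tensors and sizes below are made up), probability_fn is now passed to the attention mechanism rather than to AttentionWrapper, and the wrapper calls the mechanism with the previous alignments:

import tensorflow as tf
from tensorflow.contrib import rnn, seq2seq

# Hypothetical encoder output: [batch_size=5, max_time=8, depth=6].
encoder_outputs = tf.random_normal([5, 8, 6])
encoder_sequence_length = tf.constant([3, 2, 3, 1, 0])

# probability_fn now lives on the AttentionMechanism (default: tf.nn.softmax).
attention_mechanism = seq2seq.BahdanauAttention(
    num_units=9,
    memory=encoder_outputs,
    memory_sequence_length=encoder_sequence_length,
    probability_fn=tf.nn.softmax)

# AttentionWrapper no longer takes a probability_fn argument.
cell = seq2seq.AttentionWrapper(
    rnn.LSTMCell(9),
    attention_mechanism,
    attention_layer_size=6)

# Internally the wrapper now invokes the mechanism with the previous
# alignments carried in AttentionWrapperState, which is the hook that
# stateful schemes such as monotonic attention need:
#   alignments = attention_mechanism(cell_output,
#                                    previous_alignments=state.alignments)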


@@ -19,7 +19,7 @@ from __future__ import division
from __future__ import print_function
# pylint: enable=unused-import
import sys
import collections
import functools
import numpy as np
@@ -46,15 +46,27 @@ BasicDecoderOutput = basic_decoder.BasicDecoderOutput # pylint: disable=invalid
float32 = np.float32
int32 = np.int32
array = np.array
dtype = np.dtype
class ResultSummary(
collections.namedtuple('ResultSummary', ('shape', 'dtype', 'mean'))):
pass
def get_result_summary(x):
if isinstance(x, np.ndarray):
return ResultSummary(x.shape, x.dtype, x.mean())
return x
class AttentionWrapperTest(test.TestCase):
def assertAllClose(self, *args, **kwargs):
kwargs["atol"] = 1e-4 # For GPU tests
kwargs["rtol"] = 1e-4 # For GPU tests
return super(AttentionWrapperTest, self).assertAllClose(
*args, **kwargs)
def assertAllCloseOrEqual(self, x, y, **kwargs):
if isinstance(x, np.ndarray) or isinstance(x, float):
return super(AttentionWrapperTest, self).assertAllClose(x, y, **kwargs)
else:
self.assertAllEqual(x, y, **kwargs)
def testAttentionWrapperState(self):
num_fields = len(wrapper.AttentionWrapperState._fields) # pylint: disable=protected-access
@@ -71,7 +83,7 @@ class AttentionWrapperTest(test.TestCase):
alignment_history=False,
expected_final_alignment_history=None,
attention_layer_size=6,
name=""):
name=''):
encoder_sequence_length = [3, 2, 3, 1, 0]
decoder_sequence_length = [2, 0, 1, 2, 3]
batch_size = 5
@@ -98,7 +110,7 @@ class AttentionWrapperTest(test.TestCase):
with self.test_session(use_gpu=True) as sess:
with vs.variable_scope(
"root",
'root',
initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
cell = core_rnn_cell.LSTMCell(cell_depth)
cell = wrapper.AttentionWrapper(
@@ -147,192 +159,53 @@ class AttentionWrapperTest(test.TestCase):
sess.run(variables.global_variables_initializer())
sess_results = sess.run({
"final_outputs": final_outputs,
"final_state": final_state,
"state_alignment_history": state_alignment_history,
'final_outputs': final_outputs,
'final_state': final_state,
'state_alignment_history': state_alignment_history,
})
print("Copy/paste (%s)\nexpected_final_output = " % name,
sess_results["final_outputs"])
sys.stdout.flush()
print("Copy/paste (%s)\nexpected_final_state = " % name,
sess_results["final_state"])
sys.stdout.flush()
print("Copy/paste (%s)\nexpected_final_alignment_history = " % name,
np.asarray(sess_results["state_alignment_history"]))
sys.stdout.flush()
nest.map_structure(self.assertAllClose, expected_final_output,
sess_results["final_outputs"])
nest.map_structure(self.assertAllClose, expected_final_state,
sess_results["final_state"])
final_output_info = nest.map_structure(get_result_summary,
sess_results['final_outputs'])
final_state_info = nest.map_structure(get_result_summary,
sess_results['final_state'])
print('Copy/paste:\nexpected_final_output = %s' % str(final_output_info))
print('expected_final_state = %s' % str(final_state_info))
nest.map_structure(self.assertAllCloseOrEqual, expected_final_output,
final_output_info)
nest.map_structure(self.assertAllCloseOrEqual, expected_final_state,
final_state_info)
if alignment_history: # by default, the wrapper emits attention as output
self.assertAllClose(
final_alignment_history_info = nest.map_structure(
get_result_summary, sess_results['state_alignment_history'])
print('expected_final_alignment_history = %s' %
str(final_alignment_history_info))
nest.map_structure(
self.assertAllCloseOrEqual,
# outputs are batch major but the stacked TensorArray is time major
sess_results["state_alignment_history"],
expected_final_alignment_history)
expected_final_alignment_history,
final_alignment_history_info)
def testBahdanauNotNormalized(self):
create_attention_mechanism = wrapper.BahdanauAttention
expected_final_output = BasicDecoderOutput(
rnn_output=array(
[[[
2.04633363e-03, 1.89259532e-03, 2.09550979e-03, -3.81628517e-03,
-4.36160620e-03, -6.43933658e-03
], [
2.41885195e-03, 2.02089013e-03, 2.05879519e-03, -3.85483308e-03,
-3.51473060e-03, -6.14458136e-03
], [
2.02294230e-03, 2.06955452e-03, 2.34797411e-03, -3.62816593e-03,
-3.80352931e-03, -6.27150526e-03
]], [[
4.89025004e-03, -1.97221269e-03, 3.34283570e-03,
-2.79326970e-03, 3.63148772e-03, -4.79645561e-03
], [
5.13446378e-03, -2.03941623e-03, 3.51774949e-03,
-2.83448119e-03, 3.14159272e-03, -5.31486655e-03
], [
5.20701287e-03, -2.21262546e-03, 3.58187454e-03,
-2.85831164e-03, 3.20822699e-03, -5.20829484e-03
]], [[
-1.34046993e-03, -9.99792013e-04, -2.11631414e-03,
-1.85202830e-03, -5.26227616e-03, -9.08544939e-03
], [
-1.35486713e-03, -1.04408595e-03, -1.96779310e-03,
-1.80004584e-03, -5.61304903e-03, -9.34211537e-03
], [
-1.12452905e-03, -7.68281636e-04, -1.99770415e-03,
-1.88058324e-03, -5.01882844e-03, -9.32228006e-03
]], [[
1.52967637e-03, -3.97213362e-03, -9.64699371e-04,
8.51419638e-04, -1.29806029e-03, 6.56482670e-03
], [
1.22562144e-03, -4.56351135e-03, -1.08190742e-03,
8.27267300e-04, -2.10060296e-03, 6.43097097e-03
], [
9.93521884e-04, -4.37386986e-03, -1.41534151e-03,
6.44790183e-04, -2.16482091e-03, 6.68301852e-03
]], [[
-3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04,
-1.56512906e-04, 9.63474595e-05
], [
-1.04306288e-04, -1.37411975e-04, 2.82689070e-05,
6.56487318e-05, -1.48634164e-04, -1.84347919e-05
], [
1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04,
2.73401442e-04, -2.69805576e-04
]]],
dtype=float32),
sample_id=array(
[[2, 0, 2], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]],
dtype=int32))
rnn_output=ResultSummary(
shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.00083043973),
sample_id=ResultSummary(shape=(5, 3), dtype=dtype('int32'), mean=2.0))
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
c=array(
[[
-2.18977481e-02, -8.04181397e-03, -1.48273818e-03,
1.61075518e-02, -1.37986457e-02, -7.57964421e-03,
-8.28644261e-03, -1.18742418e-02, 1.78838037e-02
], [
1.74201727e-02, -1.41931782e-02, -3.88098788e-03,
3.19711640e-02, -3.54694054e-02, -2.14694049e-02,
-6.21706853e-03, -1.69323490e-03, -1.94494929e-02
], [
-1.14532551e-02, 8.77828151e-03, -1.62972715e-02,
-1.39963031e-02, 1.34832524e-02, -1.04488730e-02,
6.16201758e-03, -9.41041857e-03, -6.57599326e-03
], [
-4.74753827e-02, -1.19123599e-02, -7.40140676e-05,
4.10552323e-02, -1.36711076e-03, 2.11795494e-02,
-2.80460101e-02, -5.44509329e-02, -2.91906092e-02
], [
2.25644894e-02, -1.40382675e-03, 1.92396250e-02,
5.49034867e-03, -1.27930511e-02, -3.15603940e-03,
-5.05525898e-03, 2.19191350e-02, 1.62497871e-02
]],
dtype=float32),
h=array(
[[
-1.09847616e-02, -3.97357112e-03, -7.54502777e-04,
7.91223347e-03, -7.02199014e-03, -3.80705344e-03,
-4.22102772e-03, -6.05491130e-03, 8.92073940e-03
], [
8.68115202e-03, -7.16950046e-03, -1.88387593e-03,
1.62680726e-02, -1.76830068e-02, -1.06620435e-02,
-3.07523785e-03, -8.46023730e-04, -9.99386702e-03
], [
-5.71225956e-03, 4.50055022e-03, -8.07653368e-03,
-6.94842264e-03, 6.75687613e-03, -5.12083014e-03,
3.06244940e-03, -4.61752573e-03, -3.23935854e-03
], [
-2.37231534e-02, -5.88526297e-03, -3.72226204e-05,
2.01789513e-02, -6.75848918e-04, 1.06686372e-02,
-1.42624676e-02, -2.69628745e-02, -1.45034352e-02
], [
1.12585640e-02, -6.92534202e-04, 9.88917705e-03,
2.75237625e-03, -6.56115822e-03, -1.57997780e-03,
-2.54477374e-03, 1.11598391e-02, 7.94144534e-03
]],
dtype=float32)),
attention=array(
[[
0.00202294, 0.00206955, 0.00234797, -0.00362817, -0.00380353,
-0.00627151
], [
0.00520701, -0.00221263, 0.00358187, -0.00285831, 0.00320823,
-0.00520829
], [
-0.00112453, -0.00076828, -0.0019977, -0.00188058, -0.00501883,
-0.00932228
], [
0.00099352, -0.00437387, -0.00141534, 0.00064479, -0.00216482,
0.00668302
], [
0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734,
-0.00026981
]],
dtype=float32),
c=ResultSummary(
shape=(5, 9), dtype=dtype('float32'), mean=-0.0039763632),
h=ResultSummary(
shape=(5, 9), dtype=dtype('float32'), mean=-0.0019849765)),
attention=ResultSummary(
shape=(5, 6), dtype=dtype('float32'), mean=-0.00081052497),
time=3,
alignments=ResultSummary(
shape=(5, 8), dtype=dtype('float32'), mean=0.125),
alignment_history=())
expected_final_alignment_history = [[[
0.12586178, 0.12272788, 0.1271652, 0.12484902, 0.12484902, 0.12484902,
0.12484902, 0.12484902
], [
0.12612638, 0.12516938, 0.12478404, 0.12478404, 0.12478404, 0.12478404,
0.12478404, 0.12478404
], [
0.12595113, 0.12515794, 0.1255464, 0.1246689, 0.1246689, 0.1246689,
0.1246689, 0.1246689
], [
0.12492912, 0.12501013, 0.12501013, 0.12501013, 0.12501013, 0.12501013,
0.12501013, 0.12501013
], [0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125]], [[
0.12586173, 0.12272781, 0.12716517, 0.12484905, 0.12484905, 0.12484905,
0.12484905, 0.12484905
], [
0.12612617, 0.1251694, 0.12478408, 0.12478408, 0.12478408, 0.12478408,
0.12478408, 0.12478408
], [
0.12595108, 0.12515777, 0.1255464, 0.12466895, 0.12466895, 0.12466895,
0.12466895, 0.12466895
], [
0.12492914, 0.12501012, 0.12501012, 0.12501012, 0.12501012, 0.12501012,
0.12501012, 0.12501012
], [0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125]], [[
0.12586181, 0.12272815, 0.12716556, 0.12484891, 0.12484891, 0.12484891,
0.12484891, 0.12484891
], [
0.12612608, 0.12516941, 0.12478409, 0.12478409, 0.12478409, 0.12478409,
0.12478409, 0.12478409
], [
0.12595116, 0.12515792, 0.12554643, 0.1246689, 0.1246689, 0.1246689,
0.1246689, 0.1246689
], [
0.1249292, 0.12501012, 0.12501012, 0.12501012, 0.12501012, 0.12501012,
0.12501012, 0.12501012
], [0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125]]]
expected_final_alignment_history = ResultSummary(
shape=(3, 5, 8), dtype=dtype('float32'), mean=0.12500001)
self._testWithAttention(
create_attention_mechanism,
@@ -340,263 +213,54 @@ class AttentionWrapperTest(test.TestCase):
expected_final_state,
alignment_history=True,
expected_final_alignment_history=expected_final_alignment_history,
name="testBahdanauNotNormalized")
name='testBahdanauNotNormalized')
def testBahdanauNormalized(self):
create_attention_mechanism = functools.partial(
wrapper.BahdanauAttention, normalize=True)
expected_final_output = BasicDecoderOutput(
rnn_output=array(
[[[
1.27064800e-02, 3.57783446e-03, 8.22613202e-03, -1.61504047e-03,
-1.12555185e-02, -3.92740499e-03
], [
1.30781950e-02, 3.70747922e-03, 8.18992872e-03, -1.65389013e-03,
-1.04098395e-02, -3.63383139e-03
], [
1.26833543e-02, 3.75790196e-03, 8.48123431e-03, -1.42690970e-03,
-1.07016256e-02, -3.76088684e-03
]], [[
6.88417302e-03, -2.04071682e-03, 4.17768257e-03,
-4.51408979e-03, 4.90086433e-03, -6.85973791e-03
], [
7.12782983e-03, -2.10783770e-03, 4.35227761e-03,
-4.55496181e-03, 4.41066315e-03, -7.37757795e-03
], [
7.20011396e-03, -2.28102156e-03, 4.41620918e-03,
-4.57867794e-03, 4.47713351e-03, -7.27072079e-03
]], [[
-2.20676698e-03, -1.43745833e-03, -1.99429039e-03,
-1.44722988e-03, -7.45461835e-03, -9.80243273e-03
], [
-2.22120387e-03, -1.48139545e-03, -1.84528576e-03,
-1.39490096e-03, -7.80559657e-03, -1.00586927e-02
], [
-1.99079141e-03, -1.20571791e-03, -1.87507609e-03,
-1.47541985e-03, -7.21158786e-03, -1.00391749e-02
]], [[
1.48755650e-03, -3.89118027e-03, -9.40889120e-04,
8.36852356e-04, -1.28285377e-03, 6.41521579e-03
], [
1.18351437e-03, -4.48258361e-03, -1.05809816e-03,
8.12723883e-04, -2.08540238e-03, 6.28142804e-03
], [
9.51444614e-04, -4.29300033e-03, -1.39154412e-03,
6.30271854e-04, -2.14963360e-03, 6.53359853e-03
]], [[
-3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04,
-1.56512906e-04, 9.63474595e-05
], [
-1.04306288e-04, -1.37411975e-04, 2.82689070e-05,
6.56487318e-05, -1.48634164e-04, -1.84347919e-05
], [
1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04,
2.73401442e-04, -2.69805576e-04
]]],
dtype=float32),
sample_id=array(
[[0, 0, 0], [0, 0, 0], [1, 3, 1], [5, 5, 5], [3, 3, 2]],
dtype=int32))
rnn_output=ResultSummary(
shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.00040482997),
sample_id=ResultSummary(
shape=(5, 3), dtype=dtype('int32'), mean=1.8666666666666667))
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
c=array(
[[
-2.19953191e-02, -7.81358499e-03, -1.42740645e-03,
1.62037201e-02, -1.38600282e-02, -7.60386931e-03,
-8.42390209e-03, -1.18884994e-02, 1.78821683e-02
], [
1.74096227e-02, -1.41773149e-02, -3.89175024e-03,
3.19635086e-02, -3.54669318e-02, -2.14924756e-02,
-6.20695669e-03, -1.73213519e-03, -1.94583312e-02
], [
-1.14590004e-02, 8.76899902e-03, -1.62825100e-02,
-1.39863417e-02, 1.34333782e-02, -1.04652103e-02,
6.13503950e-03, -9.39247012e-03, -6.57595927e-03
], [
-4.74739373e-02, -1.19136302e-02, -7.36713409e-05,
4.10547927e-02, -1.36768632e-03, 2.11772211e-02,
-2.80480143e-02, -5.44514954e-02, -2.91903671e-02
], [
2.25644894e-02, -1.40382675e-03, 1.92396250e-02,
5.49034867e-03, -1.27930511e-02, -3.15603940e-03,
-5.05525898e-03, 2.19191350e-02, 1.62497871e-02
]],
dtype=float32),
h=array(
[[
-1.10325804e-02, -3.86056723e-03, -7.26287195e-04,
7.95945339e-03, -7.05253659e-03, -3.81913339e-03,
-4.29130904e-03, -6.06246945e-03, 8.91948957e-03
], [
8.67583323e-03, -7.16136536e-03, -1.88911252e-03,
1.62639488e-02, -1.76817775e-02, -1.06735229e-02,
-3.07015004e-03, -8.65494134e-04, -9.99815390e-03
], [
-5.71519835e-03, 4.49585915e-03, -8.06909613e-03,
-6.94347266e-03, 6.73189852e-03, -5.12895826e-03,
3.04909074e-03, -4.60868096e-03, -3.23936995e-03
], [
-2.37224363e-02, -5.88588836e-03, -3.70502457e-05,
2.01787297e-02, -6.76134136e-04, 1.06674768e-02,
-1.42634623e-02, -2.69631669e-02, -1.45033086e-02
], [
1.12585640e-02, -6.92534202e-04, 9.88917705e-03,
2.75237625e-03, -6.56115822e-03, -1.57997780e-03,
-2.54477374e-03, 1.11598391e-02, 7.94144534e-03
]],
dtype=float32)),
attention=array(
[[
0.01268335, 0.0037579, 0.00848123, -0.00142691, -0.01070163,
-0.00376089
], [
0.00720011, -0.00228102, 0.00441621, -0.00457868, 0.00447713,
-0.00727072
], [
-0.00199079, -0.00120572, -0.00187508, -0.00147542, -0.00721159,
-0.01003917
], [
0.00095144, -0.004293, -0.00139154, 0.00063027, -0.00214963,
0.0065336
], [
0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734,
-0.00026981
]],
dtype=float32),
c=ResultSummary(
shape=(5, 9), dtype=dtype('float32'), mean=-0.0039785588),
h=ResultSummary(
shape=(5, 9), dtype=dtype('float32'), mean=-0.0019861322)),
attention=ResultSummary(
shape=(5, 6), dtype=dtype('float32'), mean=-0.00038488387),
time=3,
alignments=ResultSummary(
shape=(5, 8), dtype=dtype('float32'), mean=0.125),
alignment_history=())
self._testWithAttention(
create_attention_mechanism,
expected_final_output,
expected_final_state,
name="testBahdanauNormalized")
name='testBahdanauNormalized')
def testLuongNotNormalized(self):
create_attention_mechanism = wrapper.LuongAttention
expected_final_output = BasicDecoderOutput(
rnn_output=array(
[[[
1.74922391e-03, 1.85935036e-03, 1.90880906e-03, -3.96941090e-03,
-4.17229906e-03, -6.65769773e-03
], [
1.99638237e-03, 1.91135216e-03, 1.73234346e-03, -4.00905171e-03,
-3.15058464e-03, -6.34974428e-03
], [
2.08854163e-03, 2.13832827e-03, 2.49780947e-03, -3.52849509e-03,
-3.96897132e-03, -6.12034509e-03
]], [[
4.76492243e-03, -1.97180966e-03, 3.29327444e-03,
-2.68205139e-03, 3.55229783e-03, -4.66645230e-03
], [
5.24956919e-03, -2.00631656e-03, 3.53828911e-03,
-2.96283513e-03, 3.20920302e-03, -5.43697737e-03
], [
5.30424621e-03, -2.17913301e-03, 3.59509978e-03,
-2.97106663e-03, 3.26450402e-03, -5.31189423e-03
]], [[
-1.36440888e-03, -9.75572329e-04, -2.11284542e-03,
-1.84616144e-03, -5.31351101e-03, -9.12462734e-03
], [
-1.41863467e-03, -1.11081311e-03, -1.94056751e-03,
-1.74311269e-03, -5.76282106e-03, -9.29267984e-03
], [
-1.12129003e-03, -8.15156149e-04, -2.01535341e-03,
-1.89556007e-03, -5.04226238e-03, -9.37188603e-03
]], [[
1.55163277e-03, -4.01433324e-03, -9.77111282e-04,
8.59013060e-04, -1.30598655e-03, 6.64281659e-03
], [
1.26811734e-03, -4.64518648e-03, -1.10593368e-03,
8.41954607e-04, -2.11594440e-03, 6.58190623e-03
], [
1.02682540e-03, -4.43787826e-03, -1.43417739e-03,
6.56281307e-04, -2.17684195e-03, 6.80128345e-03
]], [[
-3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04,
-1.56512906e-04, 9.63474595e-05
], [
-1.04306288e-04, -1.37411975e-04, 2.82689070e-05,
6.56487318e-05, -1.48634164e-04, -1.84347919e-05
], [
1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04,
2.73401442e-04, -2.69805576e-04
]]],
dtype=float32),
sample_id=array(
[[2, 0, 2], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]],
dtype=int32))
rnn_output=ResultSummary(
shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.00084602338),
sample_id=ResultSummary(shape=(5, 3), dtype=dtype('int32'), mean=2.0))
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
c=array(
[[
-2.18942575e-02, -8.05099495e-03, -1.48526859e-03,
1.61030665e-02, -1.37967104e-02, -7.57982396e-03,
-8.28088820e-03, -1.18743815e-02, 1.78839806e-02
], [
1.74203254e-02, -1.41929490e-02, -3.88103351e-03,
3.19709182e-02, -3.54691371e-02, -2.14697979e-02,
-6.21709181e-03, -1.69324467e-03, -1.94495786e-02
], [
-1.14536462e-02, 8.77809525e-03, -1.62965059e-02,
-1.39955431e-02, 1.34810507e-02, -1.04491040e-02,
6.16097450e-03, -9.40943789e-03, -6.57613343e-03
], [
-4.74765450e-02, -1.19113335e-02, -7.42897391e-05,
4.10555862e-02, -1.36665069e-03, 2.11814232e-02,
-2.80444007e-02, -5.44504896e-02, -2.91908123e-02
], [
2.25644894e-02, -1.40382675e-03, 1.92396250e-02,
5.49034867e-03, -1.27930511e-02, -3.15603940e-03,
-5.05525898e-03, 2.19191350e-02, 1.62497871e-02
]],
dtype=float32),
h=array(
[[
-1.09830676e-02, -3.97811923e-03, -7.55793473e-04,
7.91002903e-03, -7.02103321e-03, -3.80714820e-03,
-4.21818346e-03, -6.05497835e-03, 8.92084371e-03
], [
8.68122280e-03, -7.16937613e-03, -1.88389909e-03,
1.62679367e-02, -1.76828820e-02, -1.06622437e-02,
-3.07524228e-03, -8.46030540e-04, -9.99389403e-03
], [
-5.71245840e-03, 4.50045895e-03, -8.07614625e-03,
-6.94804778e-03, 6.75577158e-03, -5.12094703e-03,
3.06193763e-03, -4.61703911e-03, -3.23943049e-03
], [
-2.37237271e-02, -5.88475820e-03, -3.73612711e-05,
2.01791357e-02, -6.75620860e-04, 1.06695695e-02,
-1.42616741e-02, -2.69626491e-02, -1.45035451e-02
], [
1.12585640e-02, -6.92534202e-04, 9.88917705e-03,
2.75237625e-03, -6.56115822e-03, -1.57997780e-03,
-2.54477374e-03, 1.11598391e-02, 7.94144534e-03
]],
dtype=float32)),
attention=array(
[[
0.00208854, 0.00213833, 0.00249781, -0.0035285, -0.00396897,
-0.00612035
], [
0.00530425, -0.00217913, 0.0035951, -0.00297107, 0.0032645,
-0.00531189
], [
-0.00112129, -0.00081516, -0.00201535, -0.00189556, -0.00504226,
-0.00937189
], [
0.00102683, -0.00443788, -0.00143418, 0.00065628, -0.00217684,
0.00680128
], [
0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734,
-0.00026981
]],
dtype=float32),
c=ResultSummary(
shape=(5, 9), dtype=dtype('float32'), mean=-0.0039764317),
h=ResultSummary(
shape=(5, 9), dtype=dtype('float32'), mean=-0.0019850098)),
attention=ResultSummary(
shape=(5, 6), dtype=dtype('float32'), mean=-0.00080144603),
time=3,
alignments=ResultSummary(
shape=(5, 8), dtype=dtype('float32'), mean=0.125),
alignment_history=())
self._testWithAttention(
@@ -604,132 +268,27 @@ class AttentionWrapperTest(test.TestCase):
expected_final_output,
expected_final_state,
attention_mechanism_depth=9,
name="testLuongNotNormalized")
name='testLuongNotNormalized')
def testLuongScaled(self):
create_attention_mechanism = functools.partial(
wrapper.LuongAttention, scale=True)
expected_final_output = BasicDecoderOutput(
rnn_output=array(
[[[
1.74922391e-03, 1.85935036e-03, 1.90880906e-03, -3.96941090e-03,
-4.17229906e-03, -6.65769773e-03
], [
1.99638237e-03, 1.91135216e-03, 1.73234346e-03, -4.00905171e-03,
-3.15058464e-03, -6.34974428e-03
], [
2.08854163e-03, 2.13832827e-03, 2.49780947e-03, -3.52849509e-03,
-3.96897132e-03, -6.12034509e-03
]], [[
4.76492243e-03, -1.97180966e-03, 3.29327444e-03,
-2.68205139e-03, 3.55229783e-03, -4.66645230e-03
], [
5.24956919e-03, -2.00631656e-03, 3.53828911e-03,
-2.96283513e-03, 3.20920302e-03, -5.43697737e-03
], [
5.30424621e-03, -2.17913301e-03, 3.59509978e-03,
-2.97106663e-03, 3.26450402e-03, -5.31189423e-03
]], [[
-1.36440888e-03, -9.75572329e-04, -2.11284542e-03,
-1.84616144e-03, -5.31351101e-03, -9.12462734e-03
], [
-1.41863467e-03, -1.11081311e-03, -1.94056751e-03,
-1.74311269e-03, -5.76282106e-03, -9.29267984e-03
], [
-1.12129003e-03, -8.15156149e-04, -2.01535341e-03,
-1.89556007e-03, -5.04226238e-03, -9.37188603e-03
]], [[
1.55163277e-03, -4.01433324e-03, -9.77111282e-04,
8.59013060e-04, -1.30598655e-03, 6.64281659e-03
], [
1.26811734e-03, -4.64518648e-03, -1.10593368e-03,
8.41954607e-04, -2.11594440e-03, 6.58190623e-03
], [
1.02682540e-03, -4.43787826e-03, -1.43417739e-03,
6.56281307e-04, -2.17684195e-03, 6.80128345e-03
]], [[
-3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04,
-1.56512906e-04, 9.63474595e-05
], [
-1.04306288e-04, -1.37411975e-04, 2.82689070e-05,
6.56487318e-05, -1.48634164e-04, -1.84347919e-05
], [
1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04,
2.73401442e-04, -2.69805576e-04
]]],
dtype=float32),
sample_id=array(
[[2, 0, 2], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]],
dtype=int32))
rnn_output=ResultSummary(
shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.00084602338),
sample_id=ResultSummary(shape=(5, 3), dtype=dtype('int32'), mean=2.0))
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
c=array(
[[
-2.18942575e-02, -8.05099495e-03, -1.48526859e-03,
1.61030665e-02, -1.37967104e-02, -7.57982396e-03,
-8.28088820e-03, -1.18743815e-02, 1.78839806e-02
], [
1.74203254e-02, -1.41929490e-02, -3.88103351e-03,
3.19709182e-02, -3.54691371e-02, -2.14697979e-02,
-6.21709181e-03, -1.69324467e-03, -1.94495786e-02
], [
-1.14536462e-02, 8.77809525e-03, -1.62965059e-02,
-1.39955431e-02, 1.34810507e-02, -1.04491040e-02,
6.16097450e-03, -9.40943789e-03, -6.57613343e-03
], [
-4.74765450e-02, -1.19113335e-02, -7.42897391e-05,
4.10555862e-02, -1.36665069e-03, 2.11814232e-02,
-2.80444007e-02, -5.44504896e-02, -2.91908123e-02
], [
2.25644894e-02, -1.40382675e-03, 1.92396250e-02,
5.49034867e-03, -1.27930511e-02, -3.15603940e-03,
-5.05525898e-03, 2.19191350e-02, 1.62497871e-02
]],
dtype=float32),
h=array(
[[
-1.09830676e-02, -3.97811923e-03, -7.55793473e-04,
7.91002903e-03, -7.02103321e-03, -3.80714820e-03,
-4.21818346e-03, -6.05497835e-03, 8.92084371e-03
], [
8.68122280e-03, -7.16937613e-03, -1.88389909e-03,
1.62679367e-02, -1.76828820e-02, -1.06622437e-02,
-3.07524228e-03, -8.46030540e-04, -9.99389403e-03
], [
-5.71245840e-03, 4.50045895e-03, -8.07614625e-03,
-6.94804778e-03, 6.75577158e-03, -5.12094703e-03,
3.06193763e-03, -4.61703911e-03, -3.23943049e-03
], [
-2.37237271e-02, -5.88475820e-03, -3.73612711e-05,
2.01791357e-02, -6.75620860e-04, 1.06695695e-02,
-1.42616741e-02, -2.69626491e-02, -1.45035451e-02
], [
1.12585640e-02, -6.92534202e-04, 9.88917705e-03,
2.75237625e-03, -6.56115822e-03, -1.57997780e-03,
-2.54477374e-03, 1.11598391e-02, 7.94144534e-03
]],
dtype=float32)),
attention=array(
[[
0.00208854, 0.00213833, 0.00249781, -0.0035285, -0.00396897,
-0.00612035
], [
0.00530425, -0.00217913, 0.0035951, -0.00297107, 0.0032645,
-0.00531189
], [
-0.00112129, -0.00081516, -0.00201535, -0.00189556, -0.00504226,
-0.00937189
], [
0.00102683, -0.00443788, -0.00143418, 0.00065628, -0.00217684,
0.00680128
], [
0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734,
-0.00026981
]],
dtype=float32),
c=ResultSummary(
shape=(5, 9), dtype=dtype('float32'), mean=-0.0039764317),
h=ResultSummary(
shape=(5, 9), dtype=dtype('float32'), mean=-0.0019850098)),
attention=ResultSummary(
shape=(5, 6), dtype=dtype('float32'), mean=-0.00080144603),
time=3,
alignments=ResultSummary(
shape=(5, 8), dtype=dtype('float32'), mean=0.125),
alignment_history=())
self._testWithAttention(
@@ -737,116 +296,27 @@ class AttentionWrapperTest(test.TestCase):
expected_final_output,
expected_final_state,
attention_mechanism_depth=9,
name="testLuongScaled")
name='testLuongScaled')
def testNotUseAttentionLayer(self):
create_attention_mechanism = wrapper.BahdanauAttention
expected_final_output = BasicDecoderOutput(
rnn_output=array(
[[[
-0.24223405, -0.07791166, 0.15451428, 0.24738294, 0.30900395,
-0.24685201, 0.04992372, 0.18749543, -0.15878429, -0.13678923
], [
-0.2422339, -0.07791159, 0.15451418, 0.24738279, 0.30900383,
-0.24685188, 0.04992369, 0.18749531, -0.15878411, -0.13678911
], [
-0.2422343, -0.07791215, 0.15451413, 0.24738336, 0.30900475,
-0.2468522, 0.04992349, 0.18749571, -0.158785, -0.13678965
]], [[
0.40035266, 0.12299616, -0.06085059, -0.09197108, 0.11368551,
-0.15302914, 0.00566157, -0.26885766, 0.08546552, 0.18886778
], [
0.40035242, 0.12299603, -0.06085056, -0.09197091, 0.11368536,
-0.15302882, 0.0056615, -0.26885763, 0.08546554, 0.18886763
], [
0.40035242, 0.122996, -0.06085056, -0.09197087, 0.11368532,
-0.1530287, 0.00566146, -0.26885769, 0.08546556, 0.18886761
]], [[
-0.4311333, 0.07519469, -0.01551808, 0.1913045, -0.02693807,
-0.21668895, -0.02155721, 0.0013397, 0.21180844, 0.25578707
], [
-0.43113309, 0.07519454, -0.01551818, 0.19130446, -0.0269379,
-0.21668854, -0.021557, 0.00133975, 0.21180828, 0.25578681
], [
-0.43113324, 0.07519463, -0.01551815, 0.1913045, -0.02693798,
-0.21668874, -0.02155712, 0.00133973, 0.21180835, 0.25578696
]], [[
0.07059932, 0.16451572, 0.01174669, 0.04646531, 0.1427598,
0.0794456, -0.10852993, 0.15306188, 0.02151393, -0.05590061
], [
0.07059933, 0.16451576, 0.01174669, 0.04646532, 0.14275983,
0.07944562, -0.10852996, 0.15306193, 0.02151394, -0.05590062
], [
0.07059937, 0.16451585, 0.0117467, 0.04646534, 0.1427599,
0.07944567, -0.10853001, 0.153062, 0.02151395, -0.05590065
]], [[0., 0., 0., 0., 0., 0., 0., 0., 0.,
0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
dtype=float32),
sample_id=array(
[[4, 4, 4], [0, 0, 0], [9, 9, 9], [1, 1, 1], [0, 0, 0]],
dtype=int32))
rnn_output=ResultSummary(
shape=(5, 3, 10), dtype=dtype('float32'), mean=0.019546926),
sample_id=ResultSummary(
shape=(5, 3), dtype=dtype('int32'), mean=2.7999999999999998))
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
c=array(
[[
-0.0181195, -0.01675365, -0.00510353, 0.01559796,
-0.01251448, -0.00437002, -0.01243257, -0.01720199,
0.02274928
], [
0.01259979, -0.00839985, -0.00374037, 0.03136262,
-0.03486227, -0.02466441, -0.00496157, -0.00461032,
-0.02098336
], [
-0.00781067, 0.00315682, -0.0138283, -0.01149793,
0.00485562, -0.01343193, 0.0085915, -0.00632846, -0.01052086
], [
-0.04184828, -0.01223641, 0.0009445, 0.03911434, 0.0043249,
0.02220661, -0.03006243, -0.05418363, -0.02615385
], [
0.02282745, -0.00143833, 0.01918138, 0.00545033,
-0.01258384, -0.00303765, -0.00511231, 0.02166323,
0.01638841
]],
dtype=float32),
h=array(
[[
-0.00910065, -0.00827571, -0.00259689, 0.00764857,
-0.00635579, -0.00218579, -0.00633918, -0.00875511,
0.01134532
], [
0.00626597, -0.004241, -0.00181303, 0.01597157, -0.0173375,
-0.01224921, -0.00244522, -0.00231299, -0.0107822
], [
-0.00391383, 0.00162017, -0.00682621, -0.00570264,
0.00244099, -0.00659772, 0.00426475, -0.00309861,
-0.00520028
], [
-0.02087484, -0.00603306, 0.00047561, 0.01920062,
0.00213875, 0.01115329, -0.0152659, -0.02687523, -0.01297523
], [
0.01138975, -0.00070959, 0.00986007, 0.0027323, -0.00645386,
-0.00152054, -0.00257339, 0.01103063, 0.00800891
]],
dtype=float32)),
attention=array(
[[
-0.2422343, -0.07791215, 0.15451413, 0.24738336, 0.30900475,
-0.2468522, 0.04992349, 0.18749571, -0.158785, -0.13678965
], [
0.40035242, 0.122996, -0.06085056, -0.09197087, 0.11368532,
-0.1530287, 0.00566146, -0.26885769, 0.08546556, 0.18886761
], [
-0.43113324, 0.07519463, -0.01551815, 0.1913045, -0.02693798,
-0.21668874, -0.02155712, 0.00133973, 0.21180835, 0.25578696
], [
0.07059937, 0.16451585, 0.0117467, 0.04646534, 0.1427599,
0.07944567, -0.10853001, 0.153062, 0.02151395, -0.05590065
], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
dtype=float32),
c=ResultSummary(
shape=(5, 9), dtype=dtype('float32'), mean=-0.0041728448),
h=ResultSummary(
shape=(5, 9), dtype=dtype('float32'), mean=-0.002085865)),
attention=ResultSummary(
shape=(5, 10), dtype=dtype('float32'), mean=0.019546915),
time=3,
alignments=ResultSummary(
shape=(5, 8), dtype=dtype('float32'), mean=0.125),
alignment_history=())
self._testWithAttention(
@@ -854,8 +324,8 @@ class AttentionWrapperTest(test.TestCase):
expected_final_output,
expected_final_state,
attention_layer_size=None,
name="testNotUseAttentionLayer")
name='testNotUseAttentionLayer')
if __name__ == "__main__":
if __name__ == '__main__':
test.main()
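The simplified tests above compare compact summaries (shape, dtype, mean) instead of full expected tensors. A standalone sketch of that pattern, using only numpy; the expected values here are illustrative, not taken from the tests:

import collections
import numpy as np

ResultSummary = collections.namedtuple('ResultSummary', ('shape', 'dtype', 'mean'))

def get_result_summary(x):
  # Summarize ndarrays as (shape, dtype, mean); pass other values through.
  if isinstance(x, np.ndarray):
    return ResultSummary(x.shape, x.dtype, x.mean())
  return x

actual = np.full((5, 3, 6), 0.125, dtype=np.float32)   # stand-in for a session result
expected = ResultSummary((5, 3, 6), np.dtype('float32'), 0.125)

summary = get_result_summary(actual)
assert summary.shape == expected.shape
assert summary.dtype == expected.dtype
np.testing.assert_allclose(summary.mean, expected.mean, rtol=1e-4, atol=1e-4)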


@@ -121,8 +121,13 @@ class _BaseAttentionMechanism(AttentionMechanism):
2. Preprocessing and storing the memory.
"""
def __init__(self, query_layer, memory, memory_sequence_length=None,
memory_layer=None, check_inner_dims_defined=True,
def __init__(self,
query_layer,
memory,
probability_fn,
memory_sequence_length=None,
memory_layer=None,
check_inner_dims_defined=True,
name=None):
"""Construct base AttentionMechanism class.
@@ -132,6 +137,9 @@ class _BaseAttentionMechanism(AttentionMechanism):
provided, the shape of `query` must match that of `memory_layer`.
memory: The memory to query; usually the output of an RNN encoder. This
tensor should be shaped `[batch_size, max_time, ...]`.
probability_fn: A `callable`. Converts the score and previous alignments
to probabilities. Its signature should be:
`probabilities = probability_fn(score, previous_alignments)`.
memory_sequence_length (optional): Sequence lengths for the batch entries
in memory. If provided, the memory tensor rows are masked with zeros
for values past the respective sequence lengths.
@@ -154,6 +162,10 @@ class _BaseAttentionMechanism(AttentionMechanism):
"memory_layer is not a Layer: %s" % type(memory_layer).__name__)
self._query_layer = query_layer
self._memory_layer = memory_layer
if not callable(probability_fn):
raise TypeError("probability_fn must be callable, saw type: %s" %
type(probability_fn).__name__)
self._probability_fn = probability_fn
with ops.name_scope(
name, "BaseAttentionMechanismInit", nest.flatten(memory)):
self._values = _prepare_memory(
@@ -164,6 +176,8 @@ class _BaseAttentionMechanism(AttentionMechanism):
else self._values)
self._batch_size = (
self._keys.shape[0].value or array_ops.shape(self._keys)[0])
self._alignments_size = (self._keys.shape[1].value or
array_ops.shape(self._keys)[1])
@property
def memory_layer(self):
@@ -185,6 +199,29 @@ class _BaseAttentionMechanism(AttentionMechanism):
def batch_size(self):
return self._batch_size
@property
def alignments_size(self):
return self._alignments_size
def initial_alignments(self, batch_size, dtype):
"""Creates the initial alignment values for the `AttentionWrapper` class.
This is important for AttentionMechanisms that use the previous alignment
to calculate the alignment at the next time step (e.g. monotonic attention).
The default behavior is to return a tensor of all zeros.
Args:
batch_size: `int32` scalar, the batch_size.
dtype: The `dtype`.
Returns:
A `dtype` tensor shaped `[batch_size, alignments_size]`
(`alignments_size` is the values' `max_time`).
"""
max_time = self._alignments_size
return _zero_state_tensors(max_time, batch_size, dtype)
class LuongAttention(_BaseAttentionMechanism):
"""Implements Luong-style (multiplicative) attention scoring.
@@ -208,6 +245,7 @@ class LuongAttention(_BaseAttentionMechanism):
memory,
memory_sequence_length=None,
scale=False,
probability_fn=None,
name="LuongAttention"):
"""Construct the AttentionMechanism mechanism.
@@ -219,31 +257,43 @@ class LuongAttention(_BaseAttentionMechanism):
in memory. If provided, the memory tensor rows are masked with zeros
for values past the respective sequence lengths.
scale: Python boolean. Whether to scale the energy term.
probability_fn: (optional) A `callable`. Converts the score to
probabilities. The default is @{tf.nn.softmax}. Other options include
@{tf.contrib.seq2seq.hardmax} and @{tf.contrib.sparsemax.sparsemax}.
Its signature should be: `probabilities = probability_fn(score)`.
name: Name to use when creating ops.
"""
# For LuongAttention, we only transform the memory layer; thus
# num_units **must** match the expected query depth.
if probability_fn is None:
probability_fn = nn_ops.softmax
wrapped_probability_fn = lambda score, _: probability_fn(score)
super(LuongAttention, self).__init__(
query_layer=None,
memory_layer=layers_core.Dense(
num_units, name="memory_layer", use_bias=False),
memory=memory,
probability_fn=wrapped_probability_fn,
memory_sequence_length=memory_sequence_length,
name=name)
self._num_units = num_units
self._scale = scale
self._name = name
def __call__(self, query):
def __call__(self, query, previous_alignments):
"""Score the query based on the keys and values.
Args:
query: Tensor of dtype matching `self.values` and shape
`[batch_size, query_depth]`.
previous_alignments: Tensor of dtype matching `self.values` and shape
`[batch_size, alignments_size]`
(`alignments_size` is memory's `max_time`).
Returns:
score: Tensor of dtype matching `self.values` and shape
`[batch_size, max_time]` (`max_time` is memory's `max_time`).
alignments: Tensor of dtype matching `self.values` and shape
`[batch_size, alignments_size]` (`alignments_size` is memory's
`max_time`).
Raises:
ValueError: If `key` and `query` depths do not match.
@@ -281,7 +331,8 @@ class LuongAttention(_BaseAttentionMechanism):
"attention_g", dtype=dtype, initializer=1.)
score = g * score
return score
alignments = self._probability_fn(score, previous_alignments)
return alignments
class BahdanauAttention(_BaseAttentionMechanism):
@@ -311,6 +362,7 @@ class BahdanauAttention(_BaseAttentionMechanism):
memory,
memory_sequence_length=None,
normalize=False,
probability_fn=None,
name="BahdanauAttention"):
"""Construct the Attention mechanism.
@@ -322,30 +374,42 @@ class BahdanauAttention(_BaseAttentionMechanism):
in memory. If provided, the memory tensor rows are masked with zeros
for values past the respective sequence lengths.
normalize: Python boolean. Whether to normalize the energy term.
probability_fn: (optional) A `callable`. Converts the score to
probabilities. The default is @{tf.nn.softmax}. Other options include
@{tf.contrib.seq2seq.hardmax} and @{tf.contrib.sparsemax.sparsemax}.
Its signature should be: `probabilities = probability_fn(score)`.
name: Name to use when creating ops.
"""
if probability_fn is None:
probability_fn = nn_ops.softmax
wrapped_probability_fn = lambda score, _: probability_fn(score)
super(BahdanauAttention, self).__init__(
query_layer=layers_core.Dense(
num_units, name="query_layer", use_bias=False),
memory_layer=layers_core.Dense(
num_units, name="memory_layer", use_bias=False),
memory=memory,
probability_fn=wrapped_probability_fn,
memory_sequence_length=memory_sequence_length,
name=name)
self._num_units = num_units
self._normalize = normalize
self._name = name
def __call__(self, query):
def __call__(self, query, previous_alignments):
"""Score the query based on the keys and values.
Args:
query: Tensor of dtype matching `self.values` and shape
`[batch_size, query_depth]`.
previous_alignments: Tensor of dtype matching `self.values` and shape
`[batch_size, alignments_size]`
(`alignments_size` is memory's `max_time`).
Returns:
score: Tensor of dtype matching `self.values` and shape
`[batch_size, max_time]` (`max_time` is memory's `max_time`).
alignments: Tensor of dtype matching `self.values` and shape
`[batch_size, alignments_size]` (`alignments_size` is memory's
`max_time`).
"""
with variable_scope.variable_scope(None, "bahdanau_attention", [query]):
processed_query = self.query_layer(query) if self.query_layer else query
@@ -373,20 +437,23 @@ class BahdanauAttention(_BaseAttentionMechanism):
score = math_ops.reduce_sum(v * math_ops.tanh(keys + processed_query),
[2])
return score
alignments = self._probability_fn(score, previous_alignments)
return alignments
class AttentionWrapperState(
collections.namedtuple("AttentionWrapperState",
("cell_state", "attention", "time",
("cell_state", "attention", "time", "alignments",
"alignment_history"))):
"""`namedtuple` storing the state of a `AttentionWrapper`.
Contains:
- `cell_state`: The state of the wrapped `RNNCell`.
- `cell_state`: The state of the wrapped `RNNCell` at the previous time
step.
- `attention`: The attention emitted at the previous time step.
- `time`: int32 scalar containing the current time step.
- `alignments`: The alignment emitted at the previous time step.
- `alignment_history`: (if enabled) a `TensorArray` containing alignment
matrices from all time steps. Call `stack()` to convert to a `Tensor`.
"""
@@ -443,7 +510,6 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
attention_layer_size=None,
alignment_history=False,
cell_input_fn=None,
probability_fn=None,
output_attention=True,
initial_cell_state=None,
name=None):
@@ -461,9 +527,6 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
time major `TensorArray` on which you must call `stack()`).
cell_input_fn: (optional) A `callable`. The default is:
`lambda inputs, attention: array_ops.concat([inputs, attention], -1)`.
probability_fn: (optional) A `callable`. Converts the score to
probabilities. The default is @{tf.nn.softmax}. Other options include
@{tf.contrib.seq2seq.hardmax} and @{tf.contrib.sparsemax.sparsemax}.
output_attention: Python bool. If `True` (default), the output at each
time step is the attention value. This is the behavior of Luong-style
attention mechanisms. If `False`, the output at each time step is
@@ -495,13 +558,6 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
raise TypeError(
"cell_input_fn must be callable, saw type: %s"
% type(cell_input_fn).__name__)
if probability_fn is None:
probability_fn = nn_ops.softmax
else:
if not callable(probability_fn):
raise TypeError(
"probability_fn must be callable, saw type: %s"
% type(probability_fn).__name__)
if attention_layer_size is not None:
self._attention_layer = layers_core.Dense(
@@ -514,7 +570,6 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
self._cell = cell
self._attention_mechanism = attention_mechanism
self._cell_input_fn = cell_input_fn
self._probability_fn = probability_fn
self._output_attention = output_attention
self._alignment_history = alignment_history
with ops.name_scope(name, "AttentionWrapperInit"):
@@ -553,6 +608,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
cell_state=self._cell.state_size,
time=tensor_shape.TensorShape([]),
attention=self._attention_size,
alignments=self._attention_mechanism.alignments_size,
alignment_history=()) # alignment_history is sometimes a TensorArray
def zero_state(self, batch_size, dtype):
@@ -586,6 +642,8 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
time=array_ops.zeros([], dtype=dtypes.int32),
attention=_zero_state_tensors(self._attention_size, batch_size,
dtype),
alignments=self._attention_mechanism.initial_alignments(
batch_size, dtype),
alignment_history=alignment_history)
def call(self, inputs, state):
@@ -637,8 +695,8 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
cell_output = array_ops.identity(
cell_output, name="checked_cell_output")
score = self._attention_mechanism(cell_output)
alignments = self._probability_fn(score)
alignments = self._attention_mechanism(
cell_output, previous_alignments=state.alignments)
# Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
expanded_alignments = array_ops.expand_dims(alignments, 1)
@@ -671,6 +729,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
time=state.time + 1,
cell_state=next_cell_state,
attention=attention,
alignments=alignments,
alignment_history=alignment_history)
if self._output_attention: