[tf contrib seq2seq] Updates to AttentionMechanism API

* Move the probability_fn into the AttentionMechanism and out of AttentionWrapper.
* Propagate the previous alignment through the attention state and pass it to the AttentionMechanism.
* Simplify the unit tests.

Necessary for Colin's monotonic attention.

PiperOrigin-RevId: 155920141
Authored by Eugene Brevdo on 2017-05-12 15:39:26 -07:00; committed by TensorFlower Gardener
parent be9af9685c
commit 49a259db15
2 changed files with 191 additions and 662 deletions
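
For reference, a minimal sketch of the updated call sites (TF 1.x contrib at this revision; the placeholder tensors and sizes below are illustrative, not taken from this commit): the probability_fn is now configured on the attention mechanism rather than on AttentionWrapper.

# Hedged sketch of the new API surface; imports mirror the modified module.
import tensorflow as tf
from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as wrapper

encoder_outputs = tf.placeholder(tf.float32, [None, None, 10])  # [batch, max_time, depth]
encoder_lengths = tf.placeholder(tf.int32, [None])              # [batch]

attention_mechanism = wrapper.BahdanauAttention(
    num_units=9,
    memory=encoder_outputs,
    memory_sequence_length=encoder_lengths,
    probability_fn=tf.nn.softmax)  # moved here; softmax remains the default
cell = wrapper.AttentionWrapper(
    tf.contrib.rnn.LSTMCell(9),
    attention_mechanism,
    attention_layer_size=6)  # probability_fn is no longer an AttentionWrapper argument

Callers that previously passed hardmax or sparsemax to the wrapper now pass them to the mechanism instead.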


@@ -19,7 +19,7 @@ from __future__ import division
from __future__ import print_function
# pylint: enable=unused-import
import sys
import collections
import functools
import numpy as np
@@ -46,15 +46,27 @@ BasicDecoderOutput = basic_decoder.BasicDecoderOutput # pylint: disable=invalid
float32 = np.float32
int32 = np.int32
array = np.array
dtype = np.dtype
class ResultSummary(
collections.namedtuple('ResultSummary', ('shape', 'dtype', 'mean'))):
pass
def get_result_summary(x):
if isinstance(x, np.ndarray):
return ResultSummary(x.shape, x.dtype, x.mean())
return x
class AttentionWrapperTest(test.TestCase):
def assertAllClose(self, *args, **kwargs):
kwargs["atol"] = 1e-4  # For GPU tests
kwargs["rtol"] = 1e-4  # For GPU tests
return super(AttentionWrapperTest, self).assertAllClose(
*args, **kwargs)
def assertAllCloseOrEqual(self, x, y, **kwargs):
if isinstance(x, np.ndarray) or isinstance(x, float):
return super(AttentionWrapperTest, self).assertAllClose(x, y, **kwargs)
else:
self.assertAllEqual(x, y, **kwargs)
def testAttentionWrapperState(self):
num_fields = len(wrapper.AttentionWrapperState._fields)  # pylint: disable=protected-access
@@ -71,7 +83,7 @@ class AttentionWrapperTest(test.TestCase):
alignment_history=False,
expected_final_alignment_history=None,
attention_layer_size=6,
name=""):
name=''):
encoder_sequence_length = [3, 2, 3, 1, 0]
decoder_sequence_length = [2, 0, 1, 2, 3]
batch_size = 5
@@ -98,7 +110,7 @@ class AttentionWrapperTest(test.TestCase):
with self.test_session(use_gpu=True) as sess:
with vs.variable_scope(
"root",
'root',
initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
cell = core_rnn_cell.LSTMCell(cell_depth)
cell = wrapper.AttentionWrapper(
@@ -147,192 +159,53 @@ class AttentionWrapperTest(test.TestCase):
sess.run(variables.global_variables_initializer())
sess_results = sess.run({
"final_outputs": final_outputs,
'final_outputs': final_outputs,
"final_state": final_state,
'final_state': final_state,
"state_alignment_history": state_alignment_history,
'state_alignment_history': state_alignment_history,
})
print("Copy/paste (%s)\nexpected_final_output = " % name,
sess_results["final_outputs"])
sys.stdout.flush()
print("Copy/paste (%s)\nexpected_final_state = " % name,
sess_results["final_state"])
sys.stdout.flush()
print("Copy/paste (%s)\nexpected_final_alignment_history = " % name,
np.asarray(sess_results["state_alignment_history"]))
sys.stdout.flush()
nest.map_structure(self.assertAllClose, expected_final_output,
sess_results["final_outputs"])
nest.map_structure(self.assertAllClose, expected_final_state,
sess_results["final_state"])
final_output_info = nest.map_structure(get_result_summary,
sess_results['final_outputs'])
final_state_info = nest.map_structure(get_result_summary,
sess_results['final_state'])
print('Copy/paste:\nexpected_final_output = %s' % str(final_output_info))
print('expected_final_state = %s' % str(final_state_info))
nest.map_structure(self.assertAllCloseOrEqual, expected_final_output,
final_output_info)
nest.map_structure(self.assertAllCloseOrEqual, expected_final_state,
final_state_info)
if alignment_history:  # by default, the wrapper emits attention as output
self.assertAllClose(
# outputs are batch major but the stacked TensorArray is time major
sess_results["state_alignment_history"],
expected_final_alignment_history)
final_alignment_history_info = nest.map_structure(
get_result_summary, sess_results['state_alignment_history'])
print('expected_final_alignment_history = %s' %
str(final_alignment_history_info))
nest.map_structure(
self.assertAllCloseOrEqual,
# outputs are batch major but the stacked TensorArray is time major
expected_final_alignment_history,
final_alignment_history_info)
def testBahdanauNotNormalized(self):
create_attention_mechanism = wrapper.BahdanauAttention
expected_final_output = BasicDecoderOutput(
rnn_output=ResultSummary(
shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.00083043973),
sample_id=ResultSummary(shape=(5, 3), dtype=dtype('int32'), mean=2.0))
rnn_output=array(
[[[
2.04633363e-03, 1.89259532e-03, 2.09550979e-03, -3.81628517e-03,
-4.36160620e-03, -6.43933658e-03
], [
2.41885195e-03, 2.02089013e-03, 2.05879519e-03, -3.85483308e-03,
-3.51473060e-03, -6.14458136e-03
], [
2.02294230e-03, 2.06955452e-03, 2.34797411e-03, -3.62816593e-03,
-3.80352931e-03, -6.27150526e-03
]], [[
4.89025004e-03, -1.97221269e-03, 3.34283570e-03,
-2.79326970e-03, 3.63148772e-03, -4.79645561e-03
], [
5.13446378e-03, -2.03941623e-03, 3.51774949e-03,
-2.83448119e-03, 3.14159272e-03, -5.31486655e-03
], [
5.20701287e-03, -2.21262546e-03, 3.58187454e-03,
-2.85831164e-03, 3.20822699e-03, -5.20829484e-03
]], [[
-1.34046993e-03, -9.99792013e-04, -2.11631414e-03,
-1.85202830e-03, -5.26227616e-03, -9.08544939e-03
], [
-1.35486713e-03, -1.04408595e-03, -1.96779310e-03,
-1.80004584e-03, -5.61304903e-03, -9.34211537e-03
], [
-1.12452905e-03, -7.68281636e-04, -1.99770415e-03,
-1.88058324e-03, -5.01882844e-03, -9.32228006e-03
]], [[
1.52967637e-03, -3.97213362e-03, -9.64699371e-04,
8.51419638e-04, -1.29806029e-03, 6.56482670e-03
], [
1.22562144e-03, -4.56351135e-03, -1.08190742e-03,
8.27267300e-04, -2.10060296e-03, 6.43097097e-03
], [
9.93521884e-04, -4.37386986e-03, -1.41534151e-03,
6.44790183e-04, -2.16482091e-03, 6.68301852e-03
]], [[
-3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04,
-1.56512906e-04, 9.63474595e-05
], [
-1.04306288e-04, -1.37411975e-04, 2.82689070e-05,
6.56487318e-05, -1.48634164e-04, -1.84347919e-05
], [
1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04,
2.73401442e-04, -2.69805576e-04
]]],
dtype=float32),
sample_id=array(
[[2, 0, 2], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]],
dtype=int32))
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=dtype('float32'), mean=-0.0039763632),
h=ResultSummary(
shape=(5, 9), dtype=dtype('float32'), mean=-0.0019849765)),
attention=ResultSummary(
shape=(5, 6), dtype=dtype('float32'), mean=-0.00081052497),
c=array(
[[
-2.18977481e-02, -8.04181397e-03, -1.48273818e-03,
1.61075518e-02, -1.37986457e-02, -7.57964421e-03,
-8.28644261e-03, -1.18742418e-02, 1.78838037e-02
], [
1.74201727e-02, -1.41931782e-02, -3.88098788e-03,
3.19711640e-02, -3.54694054e-02, -2.14694049e-02,
-6.21706853e-03, -1.69323490e-03, -1.94494929e-02
], [
-1.14532551e-02, 8.77828151e-03, -1.62972715e-02,
-1.39963031e-02, 1.34832524e-02, -1.04488730e-02,
6.16201758e-03, -9.41041857e-03, -6.57599326e-03
], [
-4.74753827e-02, -1.19123599e-02, -7.40140676e-05,
4.10552323e-02, -1.36711076e-03, 2.11795494e-02,
-2.80460101e-02, -5.44509329e-02, -2.91906092e-02
], [
2.25644894e-02, -1.40382675e-03, 1.92396250e-02,
5.49034867e-03, -1.27930511e-02, -3.15603940e-03,
-5.05525898e-03, 2.19191350e-02, 1.62497871e-02
]],
dtype=float32),
h=array(
[[
-1.09847616e-02, -3.97357112e-03, -7.54502777e-04,
7.91223347e-03, -7.02199014e-03, -3.80705344e-03,
-4.22102772e-03, -6.05491130e-03, 8.92073940e-03
], [
8.68115202e-03, -7.16950046e-03, -1.88387593e-03,
1.62680726e-02, -1.76830068e-02, -1.06620435e-02,
-3.07523785e-03, -8.46023730e-04, -9.99386702e-03
], [
-5.71225956e-03, 4.50055022e-03, -8.07653368e-03,
-6.94842264e-03, 6.75687613e-03, -5.12083014e-03,
3.06244940e-03, -4.61752573e-03, -3.23935854e-03
], [
-2.37231534e-02, -5.88526297e-03, -3.72226204e-05,
2.01789513e-02, -6.75848918e-04, 1.06686372e-02,
-1.42624676e-02, -2.69628745e-02, -1.45034352e-02
], [
1.12585640e-02, -6.92534202e-04, 9.88917705e-03,
2.75237625e-03, -6.56115822e-03, -1.57997780e-03,
-2.54477374e-03, 1.11598391e-02, 7.94144534e-03
]],
dtype=float32)),
attention=array(
[[
0.00202294, 0.00206955, 0.00234797, -0.00362817, -0.00380353,
-0.00627151
], [
0.00520701, -0.00221263, 0.00358187, -0.00285831, 0.00320823,
-0.00520829
], [
-0.00112453, -0.00076828, -0.0019977, -0.00188058, -0.00501883,
-0.00932228
], [
0.00099352, -0.00437387, -0.00141534, 0.00064479, -0.00216482,
0.00668302
], [
0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734,
-0.00026981
]],
dtype=float32),
time=3,
alignments=ResultSummary(
shape=(5, 8), dtype=dtype('float32'), mean=0.125),
alignment_history=())
expected_final_alignment_history = ResultSummary(
shape=(3, 5, 8), dtype=dtype('float32'), mean=0.12500001)
expected_final_alignment_history = [[[
0.12586178, 0.12272788, 0.1271652, 0.12484902, 0.12484902, 0.12484902,
0.12484902, 0.12484902
], [
0.12612638, 0.12516938, 0.12478404, 0.12478404, 0.12478404, 0.12478404,
0.12478404, 0.12478404
], [
0.12595113, 0.12515794, 0.1255464, 0.1246689, 0.1246689, 0.1246689,
0.1246689, 0.1246689
], [
0.12492912, 0.12501013, 0.12501013, 0.12501013, 0.12501013, 0.12501013,
0.12501013, 0.12501013
], [0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125]], [[
0.12586173, 0.12272781, 0.12716517, 0.12484905, 0.12484905, 0.12484905,
0.12484905, 0.12484905
], [
0.12612617, 0.1251694, 0.12478408, 0.12478408, 0.12478408, 0.12478408,
0.12478408, 0.12478408
], [
0.12595108, 0.12515777, 0.1255464, 0.12466895, 0.12466895, 0.12466895,
0.12466895, 0.12466895
], [
0.12492914, 0.12501012, 0.12501012, 0.12501012, 0.12501012, 0.12501012,
0.12501012, 0.12501012
], [0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125]], [[
0.12586181, 0.12272815, 0.12716556, 0.12484891, 0.12484891, 0.12484891,
0.12484891, 0.12484891
], [
0.12612608, 0.12516941, 0.12478409, 0.12478409, 0.12478409, 0.12478409,
0.12478409, 0.12478409
], [
0.12595116, 0.12515792, 0.12554643, 0.1246689, 0.1246689, 0.1246689,
0.1246689, 0.1246689
], [
0.1249292, 0.12501012, 0.12501012, 0.12501012, 0.12501012, 0.12501012,
0.12501012, 0.12501012
], [0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125]]]
self._testWithAttention(
create_attention_mechanism,
@@ -340,263 +213,54 @@ class AttentionWrapperTest(test.TestCase):
expected_final_state,
alignment_history=True,
expected_final_alignment_history=expected_final_alignment_history,
name="testBahdanauNotNormalized")
name='testBahdanauNotNormalized')
def testBahdanauNormalized(self):
create_attention_mechanism = functools.partial(
wrapper.BahdanauAttention, normalize=True)
expected_final_output = BasicDecoderOutput(
rnn_output=ResultSummary(
shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.00040482997),
sample_id=ResultSummary(
shape=(5, 3), dtype=dtype('int32'), mean=1.8666666666666667))
rnn_output=array(
[[[
1.27064800e-02, 3.57783446e-03, 8.22613202e-03, -1.61504047e-03,
-1.12555185e-02, -3.92740499e-03
], [
1.30781950e-02, 3.70747922e-03, 8.18992872e-03, -1.65389013e-03,
-1.04098395e-02, -3.63383139e-03
], [
1.26833543e-02, 3.75790196e-03, 8.48123431e-03, -1.42690970e-03,
-1.07016256e-02, -3.76088684e-03
]], [[
6.88417302e-03, -2.04071682e-03, 4.17768257e-03,
-4.51408979e-03, 4.90086433e-03, -6.85973791e-03
], [
7.12782983e-03, -2.10783770e-03, 4.35227761e-03,
-4.55496181e-03, 4.41066315e-03, -7.37757795e-03
], [
7.20011396e-03, -2.28102156e-03, 4.41620918e-03,
-4.57867794e-03, 4.47713351e-03, -7.27072079e-03
]], [[
-2.20676698e-03, -1.43745833e-03, -1.99429039e-03,
-1.44722988e-03, -7.45461835e-03, -9.80243273e-03
], [
-2.22120387e-03, -1.48139545e-03, -1.84528576e-03,
-1.39490096e-03, -7.80559657e-03, -1.00586927e-02
], [
-1.99079141e-03, -1.20571791e-03, -1.87507609e-03,
-1.47541985e-03, -7.21158786e-03, -1.00391749e-02
]], [[
1.48755650e-03, -3.89118027e-03, -9.40889120e-04,
8.36852356e-04, -1.28285377e-03, 6.41521579e-03
], [
1.18351437e-03, -4.48258361e-03, -1.05809816e-03,
8.12723883e-04, -2.08540238e-03, 6.28142804e-03
], [
9.51444614e-04, -4.29300033e-03, -1.39154412e-03,
6.30271854e-04, -2.14963360e-03, 6.53359853e-03
]], [[
-3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04,
-1.56512906e-04, 9.63474595e-05
], [
-1.04306288e-04, -1.37411975e-04, 2.82689070e-05,
6.56487318e-05, -1.48634164e-04, -1.84347919e-05
], [
1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04,
2.73401442e-04, -2.69805576e-04
]]],
dtype=float32),
sample_id=array(
[[0, 0, 0], [0, 0, 0], [1, 3, 1], [5, 5, 5], [3, 3, 2]],
dtype=int32))
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=dtype('float32'), mean=-0.0039785588),
h=ResultSummary(
shape=(5, 9), dtype=dtype('float32'), mean=-0.0019861322)),
attention=ResultSummary(
shape=(5, 6), dtype=dtype('float32'), mean=-0.00038488387),
c=array(
[[
-2.19953191e-02, -7.81358499e-03, -1.42740645e-03,
1.62037201e-02, -1.38600282e-02, -7.60386931e-03,
-8.42390209e-03, -1.18884994e-02, 1.78821683e-02
], [
1.74096227e-02, -1.41773149e-02, -3.89175024e-03,
3.19635086e-02, -3.54669318e-02, -2.14924756e-02,
-6.20695669e-03, -1.73213519e-03, -1.94583312e-02
], [
-1.14590004e-02, 8.76899902e-03, -1.62825100e-02,
-1.39863417e-02, 1.34333782e-02, -1.04652103e-02,
6.13503950e-03, -9.39247012e-03, -6.57595927e-03
], [
-4.74739373e-02, -1.19136302e-02, -7.36713409e-05,
4.10547927e-02, -1.36768632e-03, 2.11772211e-02,
-2.80480143e-02, -5.44514954e-02, -2.91903671e-02
], [
2.25644894e-02, -1.40382675e-03, 1.92396250e-02,
5.49034867e-03, -1.27930511e-02, -3.15603940e-03,
-5.05525898e-03, 2.19191350e-02, 1.62497871e-02
]],
dtype=float32),
h=array(
[[
-1.10325804e-02, -3.86056723e-03, -7.26287195e-04,
7.95945339e-03, -7.05253659e-03, -3.81913339e-03,
-4.29130904e-03, -6.06246945e-03, 8.91948957e-03
], [
8.67583323e-03, -7.16136536e-03, -1.88911252e-03,
1.62639488e-02, -1.76817775e-02, -1.06735229e-02,
-3.07015004e-03, -8.65494134e-04, -9.99815390e-03
], [
-5.71519835e-03, 4.49585915e-03, -8.06909613e-03,
-6.94347266e-03, 6.73189852e-03, -5.12895826e-03,
3.04909074e-03, -4.60868096e-03, -3.23936995e-03
], [
-2.37224363e-02, -5.88588836e-03, -3.70502457e-05,
2.01787297e-02, -6.76134136e-04, 1.06674768e-02,
-1.42634623e-02, -2.69631669e-02, -1.45033086e-02
], [
1.12585640e-02, -6.92534202e-04, 9.88917705e-03,
2.75237625e-03, -6.56115822e-03, -1.57997780e-03,
-2.54477374e-03, 1.11598391e-02, 7.94144534e-03
]],
dtype=float32)),
attention=array(
[[
0.01268335, 0.0037579, 0.00848123, -0.00142691, -0.01070163,
-0.00376089
], [
0.00720011, -0.00228102, 0.00441621, -0.00457868, 0.00447713,
-0.00727072
], [
-0.00199079, -0.00120572, -0.00187508, -0.00147542, -0.00721159,
-0.01003917
], [
0.00095144, -0.004293, -0.00139154, 0.00063027, -0.00214963,
0.0065336
], [
0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734,
-0.00026981
]],
dtype=float32),
time=3,
alignments=ResultSummary(
shape=(5, 8), dtype=dtype('float32'), mean=0.125),
alignment_history=())
self._testWithAttention(
create_attention_mechanism,
expected_final_output,
expected_final_state,
name="testBahdanauNormalized")
name='testBahdanauNormalized')
def testLuongNotNormalized(self):
create_attention_mechanism = wrapper.LuongAttention
expected_final_output = BasicDecoderOutput(
rnn_output=ResultSummary(
shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.00084602338),
sample_id=ResultSummary(shape=(5, 3), dtype=dtype('int32'), mean=2.0))
rnn_output=array(
[[[
1.74922391e-03, 1.85935036e-03, 1.90880906e-03, -3.96941090e-03,
-4.17229906e-03, -6.65769773e-03
], [
1.99638237e-03, 1.91135216e-03, 1.73234346e-03, -4.00905171e-03,
-3.15058464e-03, -6.34974428e-03
], [
2.08854163e-03, 2.13832827e-03, 2.49780947e-03, -3.52849509e-03,
-3.96897132e-03, -6.12034509e-03
]], [[
4.76492243e-03, -1.97180966e-03, 3.29327444e-03,
-2.68205139e-03, 3.55229783e-03, -4.66645230e-03
], [
5.24956919e-03, -2.00631656e-03, 3.53828911e-03,
-2.96283513e-03, 3.20920302e-03, -5.43697737e-03
], [
5.30424621e-03, -2.17913301e-03, 3.59509978e-03,
-2.97106663e-03, 3.26450402e-03, -5.31189423e-03
]], [[
-1.36440888e-03, -9.75572329e-04, -2.11284542e-03,
-1.84616144e-03, -5.31351101e-03, -9.12462734e-03
], [
-1.41863467e-03, -1.11081311e-03, -1.94056751e-03,
-1.74311269e-03, -5.76282106e-03, -9.29267984e-03
], [
-1.12129003e-03, -8.15156149e-04, -2.01535341e-03,
-1.89556007e-03, -5.04226238e-03, -9.37188603e-03
]], [[
1.55163277e-03, -4.01433324e-03, -9.77111282e-04,
8.59013060e-04, -1.30598655e-03, 6.64281659e-03
], [
1.26811734e-03, -4.64518648e-03, -1.10593368e-03,
8.41954607e-04, -2.11594440e-03, 6.58190623e-03
], [
1.02682540e-03, -4.43787826e-03, -1.43417739e-03,
6.56281307e-04, -2.17684195e-03, 6.80128345e-03
]], [[
-3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04,
-1.56512906e-04, 9.63474595e-05
], [
-1.04306288e-04, -1.37411975e-04, 2.82689070e-05,
6.56487318e-05, -1.48634164e-04, -1.84347919e-05
], [
1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04,
2.73401442e-04, -2.69805576e-04
]]],
dtype=float32),
sample_id=array(
[[2, 0, 2], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]],
dtype=int32))
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=dtype('float32'), mean=-0.0039764317),
h=ResultSummary(
shape=(5, 9), dtype=dtype('float32'), mean=-0.0019850098)),
attention=ResultSummary(
shape=(5, 6), dtype=dtype('float32'), mean=-0.00080144603),
c=array(
[[
-2.18942575e-02, -8.05099495e-03, -1.48526859e-03,
1.61030665e-02, -1.37967104e-02, -7.57982396e-03,
-8.28088820e-03, -1.18743815e-02, 1.78839806e-02
], [
1.74203254e-02, -1.41929490e-02, -3.88103351e-03,
3.19709182e-02, -3.54691371e-02, -2.14697979e-02,
-6.21709181e-03, -1.69324467e-03, -1.94495786e-02
], [
-1.14536462e-02, 8.77809525e-03, -1.62965059e-02,
-1.39955431e-02, 1.34810507e-02, -1.04491040e-02,
6.16097450e-03, -9.40943789e-03, -6.57613343e-03
], [
-4.74765450e-02, -1.19113335e-02, -7.42897391e-05,
4.10555862e-02, -1.36665069e-03, 2.11814232e-02,
-2.80444007e-02, -5.44504896e-02, -2.91908123e-02
], [
2.25644894e-02, -1.40382675e-03, 1.92396250e-02,
5.49034867e-03, -1.27930511e-02, -3.15603940e-03,
-5.05525898e-03, 2.19191350e-02, 1.62497871e-02
]],
dtype=float32),
h=array(
[[
-1.09830676e-02, -3.97811923e-03, -7.55793473e-04,
7.91002903e-03, -7.02103321e-03, -3.80714820e-03,
-4.21818346e-03, -6.05497835e-03, 8.92084371e-03
], [
8.68122280e-03, -7.16937613e-03, -1.88389909e-03,
1.62679367e-02, -1.76828820e-02, -1.06622437e-02,
-3.07524228e-03, -8.46030540e-04, -9.99389403e-03
], [
-5.71245840e-03, 4.50045895e-03, -8.07614625e-03,
-6.94804778e-03, 6.75577158e-03, -5.12094703e-03,
3.06193763e-03, -4.61703911e-03, -3.23943049e-03
], [
-2.37237271e-02, -5.88475820e-03, -3.73612711e-05,
2.01791357e-02, -6.75620860e-04, 1.06695695e-02,
-1.42616741e-02, -2.69626491e-02, -1.45035451e-02
], [
1.12585640e-02, -6.92534202e-04, 9.88917705e-03,
2.75237625e-03, -6.56115822e-03, -1.57997780e-03,
-2.54477374e-03, 1.11598391e-02, 7.94144534e-03
]],
dtype=float32)),
attention=array(
[[
0.00208854, 0.00213833, 0.00249781, -0.0035285, -0.00396897,
-0.00612035
], [
0.00530425, -0.00217913, 0.0035951, -0.00297107, 0.0032645,
-0.00531189
], [
-0.00112129, -0.00081516, -0.00201535, -0.00189556, -0.00504226,
-0.00937189
], [
0.00102683, -0.00443788, -0.00143418, 0.00065628, -0.00217684,
0.00680128
], [
0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734,
-0.00026981
]],
dtype=float32),
time=3,
alignments=ResultSummary(
shape=(5, 8), dtype=dtype('float32'), mean=0.125),
alignment_history=())
self._testWithAttention(
@@ -604,132 +268,27 @@ class AttentionWrapperTest(test.TestCase):
expected_final_output,
expected_final_state,
attention_mechanism_depth=9,
name="testLuongNotNormalized")
name='testLuongNotNormalized')
def testLuongScaled(self):
create_attention_mechanism = functools.partial(
wrapper.LuongAttention, scale=True)
expected_final_output = BasicDecoderOutput(
rnn_output=ResultSummary(
shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.00084602338),
sample_id=ResultSummary(shape=(5, 3), dtype=dtype('int32'), mean=2.0))
rnn_output=array(
[[[
1.74922391e-03, 1.85935036e-03, 1.90880906e-03, -3.96941090e-03,
-4.17229906e-03, -6.65769773e-03
], [
1.99638237e-03, 1.91135216e-03, 1.73234346e-03, -4.00905171e-03,
-3.15058464e-03, -6.34974428e-03
], [
2.08854163e-03, 2.13832827e-03, 2.49780947e-03, -3.52849509e-03,
-3.96897132e-03, -6.12034509e-03
]], [[
4.76492243e-03, -1.97180966e-03, 3.29327444e-03,
-2.68205139e-03, 3.55229783e-03, -4.66645230e-03
], [
5.24956919e-03, -2.00631656e-03, 3.53828911e-03,
-2.96283513e-03, 3.20920302e-03, -5.43697737e-03
], [
5.30424621e-03, -2.17913301e-03, 3.59509978e-03,
-2.97106663e-03, 3.26450402e-03, -5.31189423e-03
]], [[
-1.36440888e-03, -9.75572329e-04, -2.11284542e-03,
-1.84616144e-03, -5.31351101e-03, -9.12462734e-03
], [
-1.41863467e-03, -1.11081311e-03, -1.94056751e-03,
-1.74311269e-03, -5.76282106e-03, -9.29267984e-03
], [
-1.12129003e-03, -8.15156149e-04, -2.01535341e-03,
-1.89556007e-03, -5.04226238e-03, -9.37188603e-03
]], [[
1.55163277e-03, -4.01433324e-03, -9.77111282e-04,
8.59013060e-04, -1.30598655e-03, 6.64281659e-03
], [
1.26811734e-03, -4.64518648e-03, -1.10593368e-03,
8.41954607e-04, -2.11594440e-03, 6.58190623e-03
], [
1.02682540e-03, -4.43787826e-03, -1.43417739e-03,
6.56281307e-04, -2.17684195e-03, 6.80128345e-03
]], [[
-3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04,
-1.56512906e-04, 9.63474595e-05
], [
-1.04306288e-04, -1.37411975e-04, 2.82689070e-05,
6.56487318e-05, -1.48634164e-04, -1.84347919e-05
], [
1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04,
2.73401442e-04, -2.69805576e-04
]]],
dtype=float32),
sample_id=array(
[[2, 0, 2], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]],
dtype=int32))
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=dtype('float32'), mean=-0.0039764317),
h=ResultSummary(
shape=(5, 9), dtype=dtype('float32'), mean=-0.0019850098)),
attention=ResultSummary(
shape=(5, 6), dtype=dtype('float32'), mean=-0.00080144603),
c=array(
[[
-2.18942575e-02, -8.05099495e-03, -1.48526859e-03,
1.61030665e-02, -1.37967104e-02, -7.57982396e-03,
-8.28088820e-03, -1.18743815e-02, 1.78839806e-02
], [
1.74203254e-02, -1.41929490e-02, -3.88103351e-03,
3.19709182e-02, -3.54691371e-02, -2.14697979e-02,
-6.21709181e-03, -1.69324467e-03, -1.94495786e-02
], [
-1.14536462e-02, 8.77809525e-03, -1.62965059e-02,
-1.39955431e-02, 1.34810507e-02, -1.04491040e-02,
6.16097450e-03, -9.40943789e-03, -6.57613343e-03
], [
-4.74765450e-02, -1.19113335e-02, -7.42897391e-05,
4.10555862e-02, -1.36665069e-03, 2.11814232e-02,
-2.80444007e-02, -5.44504896e-02, -2.91908123e-02
], [
2.25644894e-02, -1.40382675e-03, 1.92396250e-02,
5.49034867e-03, -1.27930511e-02, -3.15603940e-03,
-5.05525898e-03, 2.19191350e-02, 1.62497871e-02
]],
dtype=float32),
h=array(
[[
-1.09830676e-02, -3.97811923e-03, -7.55793473e-04,
7.91002903e-03, -7.02103321e-03, -3.80714820e-03,
-4.21818346e-03, -6.05497835e-03, 8.92084371e-03
], [
8.68122280e-03, -7.16937613e-03, -1.88389909e-03,
1.62679367e-02, -1.76828820e-02, -1.06622437e-02,
-3.07524228e-03, -8.46030540e-04, -9.99389403e-03
], [
-5.71245840e-03, 4.50045895e-03, -8.07614625e-03,
-6.94804778e-03, 6.75577158e-03, -5.12094703e-03,
3.06193763e-03, -4.61703911e-03, -3.23943049e-03
], [
-2.37237271e-02, -5.88475820e-03, -3.73612711e-05,
2.01791357e-02, -6.75620860e-04, 1.06695695e-02,
-1.42616741e-02, -2.69626491e-02, -1.45035451e-02
], [
1.12585640e-02, -6.92534202e-04, 9.88917705e-03,
2.75237625e-03, -6.56115822e-03, -1.57997780e-03,
-2.54477374e-03, 1.11598391e-02, 7.94144534e-03
]],
dtype=float32)),
attention=array(
[[
0.00208854, 0.00213833, 0.00249781, -0.0035285, -0.00396897,
-0.00612035
], [
0.00530425, -0.00217913, 0.0035951, -0.00297107, 0.0032645,
-0.00531189
], [
-0.00112129, -0.00081516, -0.00201535, -0.00189556, -0.00504226,
-0.00937189
], [
0.00102683, -0.00443788, -0.00143418, 0.00065628, -0.00217684,
0.00680128
], [
0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734,
-0.00026981
]],
dtype=float32),
time=3,
alignments=ResultSummary(
shape=(5, 8), dtype=dtype('float32'), mean=0.125),
alignment_history=())
self._testWithAttention(
@@ -737,116 +296,27 @@ class AttentionWrapperTest(test.TestCase):
expected_final_output,
expected_final_state,
attention_mechanism_depth=9,
name="testLuongScaled")
name='testLuongScaled')
def testNotUseAttentionLayer(self):
create_attention_mechanism = wrapper.BahdanauAttention
expected_final_output = BasicDecoderOutput(
rnn_output=ResultSummary(
shape=(5, 3, 10), dtype=dtype('float32'), mean=0.019546926),
sample_id=ResultSummary(
shape=(5, 3), dtype=dtype('int32'), mean=2.7999999999999998))
rnn_output=array(
[[[
-0.24223405, -0.07791166, 0.15451428, 0.24738294, 0.30900395,
-0.24685201, 0.04992372, 0.18749543, -0.15878429, -0.13678923
], [
-0.2422339, -0.07791159, 0.15451418, 0.24738279, 0.30900383,
-0.24685188, 0.04992369, 0.18749531, -0.15878411, -0.13678911
], [
-0.2422343, -0.07791215, 0.15451413, 0.24738336, 0.30900475,
-0.2468522, 0.04992349, 0.18749571, -0.158785, -0.13678965
]], [[
0.40035266, 0.12299616, -0.06085059, -0.09197108, 0.11368551,
-0.15302914, 0.00566157, -0.26885766, 0.08546552, 0.18886778
], [
0.40035242, 0.12299603, -0.06085056, -0.09197091, 0.11368536,
-0.15302882, 0.0056615, -0.26885763, 0.08546554, 0.18886763
], [
0.40035242, 0.122996, -0.06085056, -0.09197087, 0.11368532,
-0.1530287, 0.00566146, -0.26885769, 0.08546556, 0.18886761
]], [[
-0.4311333, 0.07519469, -0.01551808, 0.1913045, -0.02693807,
-0.21668895, -0.02155721, 0.0013397, 0.21180844, 0.25578707
], [
-0.43113309, 0.07519454, -0.01551818, 0.19130446, -0.0269379,
-0.21668854, -0.021557, 0.00133975, 0.21180828, 0.25578681
], [
-0.43113324, 0.07519463, -0.01551815, 0.1913045, -0.02693798,
-0.21668874, -0.02155712, 0.00133973, 0.21180835, 0.25578696
]], [[
0.07059932, 0.16451572, 0.01174669, 0.04646531, 0.1427598,
0.0794456, -0.10852993, 0.15306188, 0.02151393, -0.05590061
], [
0.07059933, 0.16451576, 0.01174669, 0.04646532, 0.14275983,
0.07944562, -0.10852996, 0.15306193, 0.02151394, -0.05590062
], [
0.07059937, 0.16451585, 0.0117467, 0.04646534, 0.1427599,
0.07944567, -0.10853001, 0.153062, 0.02151395, -0.05590065
]], [[0., 0., 0., 0., 0., 0., 0., 0., 0.,
0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
dtype=float32),
sample_id=array(
[[4, 4, 4], [0, 0, 0], [9, 9, 9], [1, 1, 1], [0, 0, 0]],
dtype=int32))
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=dtype('float32'), mean=-0.0041728448),
h=ResultSummary(
shape=(5, 9), dtype=dtype('float32'), mean=-0.002085865)),
attention=ResultSummary(
shape=(5, 10), dtype=dtype('float32'), mean=0.019546915),
c=array(
[[
-0.0181195, -0.01675365, -0.00510353, 0.01559796,
-0.01251448, -0.00437002, -0.01243257, -0.01720199,
0.02274928
], [
0.01259979, -0.00839985, -0.00374037, 0.03136262,
-0.03486227, -0.02466441, -0.00496157, -0.00461032,
-0.02098336
], [
-0.00781067, 0.00315682, -0.0138283, -0.01149793,
0.00485562, -0.01343193, 0.0085915, -0.00632846, -0.01052086
], [
-0.04184828, -0.01223641, 0.0009445, 0.03911434, 0.0043249,
0.02220661, -0.03006243, -0.05418363, -0.02615385
], [
0.02282745, -0.00143833, 0.01918138, 0.00545033,
-0.01258384, -0.00303765, -0.00511231, 0.02166323,
0.01638841
]],
dtype=float32),
h=array(
[[
-0.00910065, -0.00827571, -0.00259689, 0.00764857,
-0.00635579, -0.00218579, -0.00633918, -0.00875511,
0.01134532
], [
0.00626597, -0.004241, -0.00181303, 0.01597157, -0.0173375,
-0.01224921, -0.00244522, -0.00231299, -0.0107822
], [
-0.00391383, 0.00162017, -0.00682621, -0.00570264,
0.00244099, -0.00659772, 0.00426475, -0.00309861,
-0.00520028
], [
-0.02087484, -0.00603306, 0.00047561, 0.01920062,
0.00213875, 0.01115329, -0.0152659, -0.02687523, -0.01297523
], [
0.01138975, -0.00070959, 0.00986007, 0.0027323, -0.00645386,
-0.00152054, -0.00257339, 0.01103063, 0.00800891
]],
dtype=float32)),
attention=array(
[[
-0.2422343, -0.07791215, 0.15451413, 0.24738336, 0.30900475,
-0.2468522, 0.04992349, 0.18749571, -0.158785, -0.13678965
], [
0.40035242, 0.122996, -0.06085056, -0.09197087, 0.11368532,
-0.1530287, 0.00566146, -0.26885769, 0.08546556, 0.18886761
], [
-0.43113324, 0.07519463, -0.01551815, 0.1913045, -0.02693798,
-0.21668874, -0.02155712, 0.00133973, 0.21180835, 0.25578696
], [
0.07059937, 0.16451585, 0.0117467, 0.04646534, 0.1427599,
0.07944567, -0.10853001, 0.153062, 0.02151395, -0.05590065
], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
dtype=float32),
time=3,
alignments=ResultSummary(
shape=(5, 8), dtype=dtype('float32'), mean=0.125),
alignment_history=())
self._testWithAttention(
@@ -854,8 +324,8 @@ class AttentionWrapperTest(test.TestCase):
expected_final_output,
expected_final_state,
attention_layer_size=None,
name="testNotUseAttentionLayer")
name='testNotUseAttentionLayer')
if __name__ == "__main__":
if __name__ == '__main__':
test.main()


@@ -121,8 +121,13 @@ class _BaseAttentionMechanism(AttentionMechanism):
2. Preprocessing and storing the memory.
"""
def __init__(self, query_layer, memory, memory_sequence_length=None,
memory_layer=None, check_inner_dims_defined=True,
name=None):
def __init__(self,
query_layer,
memory,
probability_fn,
memory_sequence_length=None,
memory_layer=None,
check_inner_dims_defined=True,
name=None):
"""Construct base AttentionMechanism class.
@@ -132,6 +137,9 @@ class _BaseAttentionMechanism(AttentionMechanism):
provided, the shape of `query` must match that of `memory_layer`.
memory: The memory to query; usually the output of an RNN encoder. This
tensor should be shaped `[batch_size, max_time, ...]`.
probability_fn: A `callable`. Converts the score and previous alignments
to probabilities. Its signature should be:
`probabilities = probability_fn(score, previous_alignments)`.
memory_sequence_length (optional): Sequence lengths for the batch entries
in memory. If provided, the memory tensor rows are masked with zeros
for values past the respective sequence lengths.
@@ -154,6 +162,10 @@ class _BaseAttentionMechanism(AttentionMechanism):
"memory_layer is not a Layer: %s" % type(memory_layer).__name__)
self._query_layer = query_layer
self._memory_layer = memory_layer
if not callable(probability_fn):
raise TypeError("probability_fn must be callable, saw type: %s" %
type(probability_fn).__name__)
self._probability_fn = probability_fn
with ops.name_scope(
name, "BaseAttentionMechanismInit", nest.flatten(memory)):
self._values = _prepare_memory(
@@ -164,6 +176,8 @@ class _BaseAttentionMechanism(AttentionMechanism):
else self._values)
self._batch_size = (
self._keys.shape[0].value or array_ops.shape(self._keys)[0])
self._alignments_size = (self._keys.shape[1].value or
array_ops.shape(self._keys)[1])
@property
def memory_layer(self):
@@ -185,6 +199,29 @@ class _BaseAttentionMechanism(AttentionMechanism):
def batch_size(self):
return self._batch_size
@property
def alignments_size(self):
return self._alignments_size
def initial_alignments(self, batch_size, dtype):
"""Creates the initial alignment values for the `AttentionWrapper` class.
This is important for AttentionMechanisms that use the previous alignment
to calculate the alignment at the next time step (e.g. monotonic attention).
The default behavior is to return a tensor of all zeros.
Args:
batch_size: `int32` scalar, the batch_size.
dtype: The `dtype`.
Returns:
A `dtype` tensor shaped `[batch_size, alignments_size]`
(`alignments_size` is the values' `max_time`).
"""
max_time = self._alignments_size
return _zero_state_tensors(max_time, batch_size, dtype)
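
Because the base class now receives a two-argument `probability_fn(score, previous_alignments)` and exposes `initial_alignments`, a derived mechanism can make each step's distribution depend on where it attended at the previous step. A rough, hypothetical illustration of such a callable (not part of this commit) follows; the built-in LuongAttention and BahdanauAttention below simply wrap a one-argument probability_fn and ignore previous_alignments.

# Hypothetical sketch only. `score` and `previous_alignments` are both
# [batch_size, max_time] tensors, matching the signature documented above.
import tensorflow as tf

def cumulative_probability_fn(score, previous_alignments):
  # At each memory position, subtract the attention mass the previous step
  # placed on strictly earlier positions, then renormalize with softmax.
  penalty = tf.cumsum(previous_alignments, axis=1, exclusive=True)
  return tf.nn.softmax(score - penalty)

A subclass would pass such a callable as the probability_fn argument of _BaseAttentionMechanism.__init__ and, if needed, override initial_alignments to seed the recursion with something other than zeros.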
class LuongAttention(_BaseAttentionMechanism):
"""Implements Luong-style (multiplicative) attention scoring.
@@ -208,6 +245,7 @@ class LuongAttention(_BaseAttentionMechanism):
memory,
memory_sequence_length=None,
scale=False,
probability_fn=None,
name="LuongAttention"): name="LuongAttention"):
"""Construct the AttentionMechanism mechanism. """Construct the AttentionMechanism mechanism.
@@ -219,31 +257,43 @@ class LuongAttention(_BaseAttentionMechanism):
in memory. If provided, the memory tensor rows are masked with zeros
for values past the respective sequence lengths.
scale: Python boolean. Whether to scale the energy term.
probability_fn: (optional) A `callable`. Converts the score to
probabilities. The default is @{tf.nn.softmax}. Other options include
@{tf.contrib.seq2seq.hardmax} and @{tf.contrib.sparsemax.sparsemax}.
Its signature should be: `probabilities = probability_fn(score)`.
name: Name to use when creating ops.
"""
# For LuongAttention, we only transform the memory layer; thus
# num_units **must** match expected the query depth.
if probability_fn is None:
probability_fn = nn_ops.softmax
wrapped_probability_fn = lambda score, _: probability_fn(score)
super(LuongAttention, self).__init__(
query_layer=None,
memory_layer=layers_core.Dense(
num_units, name="memory_layer", use_bias=False),
memory=memory,
probability_fn=wrapped_probability_fn,
memory_sequence_length=memory_sequence_length,
name=name)
self._num_units = num_units
self._scale = scale
self._name = name
def __call__(self, query):
def __call__(self, query, previous_alignments):
"""Score the query based on the keys and values.
Args:
query: Tensor of dtype matching `self.values` and shape
`[batch_size, query_depth]`.
previous_alignments: Tensor of dtype matching `self.values` and shape
`[batch_size, alignments_size]`
(`alignments_size` is memory's `max_time`).
Returns:
score: Tensor of dtype matching `self.values` and shape
`[batch_size, max_time]` (`max_time` is memory's `max_time`).
alignments: Tensor of dtype matching `self.values` and shape
`[batch_size, alignments_size]` (`alignments_size` is memory's
`max_time`).
Raises:
ValueError: If `key` and `query` depths do not match.
@@ -281,7 +331,8 @@ class LuongAttention(_BaseAttentionMechanism):
"attention_g", dtype=dtype, initializer=1.)
score = g * score
return score
alignments = self._probability_fn(score, previous_alignments)
return alignments
class BahdanauAttention(_BaseAttentionMechanism):
@@ -311,6 +362,7 @@ class BahdanauAttention(_BaseAttentionMechanism):
memory,
memory_sequence_length=None,
normalize=False,
probability_fn=None,
name="BahdanauAttention"): name="BahdanauAttention"):
"""Construct the Attention mechanism. """Construct the Attention mechanism.
@ -322,30 +374,42 @@ class BahdanauAttention(_BaseAttentionMechanism):
in memory. If provided, the memory tensor rows are masked with zeros in memory. If provided, the memory tensor rows are masked with zeros
for values past the respective sequence lengths. for values past the respective sequence lengths.
normalize: Python boolean. Whether to normalize the energy term. normalize: Python boolean. Whether to normalize the energy term.
probability_fn: (optional) A `callable`. Converts the score to
probabilities. The default is @{tf.nn.softmax}. Other options include
@{tf.contrib.seq2seq.hardmax} and @{tf.contrib.sparsemax.sparsemax}.
Its signature should be: `probabilities = probability_fn(score)`.
name: Name to use when creating ops.
"""
if probability_fn is None:
probability_fn = nn_ops.softmax
wrapped_probability_fn = lambda score, _: probability_fn(score)
super(BahdanauAttention, self).__init__(
query_layer=layers_core.Dense(
num_units, name="query_layer", use_bias=False),
memory_layer=layers_core.Dense(
num_units, name="memory_layer", use_bias=False),
memory=memory,
probability_fn=wrapped_probability_fn,
memory_sequence_length=memory_sequence_length,
name=name)
self._num_units = num_units
self._normalize = normalize
self._name = name
def __call__(self, query):
def __call__(self, query, previous_alignments):
"""Score the query based on the keys and values.
Args:
query: Tensor of dtype matching `self.values` and shape
`[batch_size, query_depth]`.
previous_alignments: Tensor of dtype matching `self.values` and shape
`[batch_size, alignments_size]`
(`alignments_size` is memory's `max_time`).
Returns:
score: Tensor of dtype matching `self.values` and shape
`[batch_size, max_time]` (`max_time` is memory's `max_time`).
alignments: Tensor of dtype matching `self.values` and shape
`[batch_size, alignments_size]` (`alignments_size` is memory's
`max_time`).
""" """
with variable_scope.variable_scope(None, "bahdanau_attention", [query]): with variable_scope.variable_scope(None, "bahdanau_attention", [query]):
processed_query = self.query_layer(query) if self.query_layer else query processed_query = self.query_layer(query) if self.query_layer else query
@@ -373,20 +437,23 @@ class BahdanauAttention(_BaseAttentionMechanism):
score = math_ops.reduce_sum(v * math_ops.tanh(keys + processed_query),
[2])
return score
alignments = self._probability_fn(score, previous_alignments)
return alignments
class AttentionWrapperState(
collections.namedtuple("AttentionWrapperState",
("cell_state", "attention", "time",
("cell_state", "attention", "time", "alignments",
"alignment_history"))):
"""`namedtuple` storing the state of a `AttentionWrapper`.
Contains:
- `cell_state`: The state of the wrapped `RNNCell`.
- `cell_state`: The state of the wrapped `RNNCell` at the previous time
step.
- `attention`: The attention emitted at the previous time step.
- `time`: int32 scalar containing the current time step.
- `alignments`: The alignment emitted at the previous time step.
- `alignment_history`: (if enabled) a `TensorArray` containing alignment
matrices from all time steps. Call `stack()` to convert to a `Tensor`.
"""
@@ -443,7 +510,6 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
attention_layer_size=None,
alignment_history=False,
cell_input_fn=None,
probability_fn=None,
output_attention=True,
initial_cell_state=None,
name=None):
@@ -461,9 +527,6 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
time major `TensorArray` on which you must call `stack()`).
cell_input_fn: (optional) A `callable`. The default is:
`lambda inputs, attention: array_ops.concat([inputs, attention], -1)`.
probability_fn: (optional) A `callable`. Converts the score to
probabilities. The default is @{tf.nn.softmax}. Other options include
@{tf.contrib.seq2seq.hardmax} and @{tf.contrib.sparsemax.sparsemax}.
output_attention: Python bool. If `True` (default), the output at each
time step is the attention value. This is the behavior of Luong-style
attention mechanisms. If `False`, the output at each time step is
@@ -495,13 +558,6 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
raise TypeError(
"cell_input_fn must be callable, saw type: %s"
% type(cell_input_fn).__name__)
if probability_fn is None:
probability_fn = nn_ops.softmax
else:
if not callable(probability_fn):
raise TypeError(
"probability_fn must be callable, saw type: %s"
% type(probability_fn).__name__)
if attention_layer_size is not None:
self._attention_layer = layers_core.Dense(
@@ -514,7 +570,6 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
self._cell = cell
self._attention_mechanism = attention_mechanism
self._cell_input_fn = cell_input_fn
self._probability_fn = probability_fn
self._output_attention = output_attention
self._alignment_history = alignment_history
with ops.name_scope(name, "AttentionWrapperInit"):
@@ -553,6 +608,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
cell_state=self._cell.state_size,
time=tensor_shape.TensorShape([]),
attention=self._attention_size,
alignments=self._attention_mechanism.alignments_size,
alignment_history=())  # alignment_history is sometimes a TensorArray
def zero_state(self, batch_size, dtype):
@@ -586,6 +642,8 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
time=array_ops.zeros([], dtype=dtypes.int32),
attention=_zero_state_tensors(self._attention_size, batch_size,
dtype),
alignments=self._attention_mechanism.initial_alignments(
batch_size, dtype),
alignment_history=alignment_history)
def call(self, inputs, state):
@@ -637,8 +695,8 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
cell_output = array_ops.identity(
cell_output, name="checked_cell_output")
score = self._attention_mechanism(cell_output)
alignments = self._probability_fn(score)
alignments = self._attention_mechanism(
cell_output, previous_alignments=state.alignments)
# Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
expanded_alignments = array_ops.expand_dims(alignments, 1)
@@ -671,6 +729,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
time=state.time + 1,
cell_state=next_cell_state,
attention=attention,
alignments=alignments,
alignment_history=alignment_history)
if self._output_attention: