From 49a259db1517834be954e037497564ea70ce56a1 Mon Sep 17 00:00:00 2001
From: Eugene Brevdo
Date: Fri, 12 May 2017 15:39:26 -0700
Subject: [PATCH] [tf contrib seq2seq] Updates to AttentionMechanism API

* Move the probability_fn into the AttentionMechanism and out of AttentionWrapper.
* Propagate the previous alignment through the attention state and pass it to the AttentionMechanism.
* Simplify the unit tests.

Necessary for Colin's monotonic attention.

PiperOrigin-RevId: 155920141
---
 .../kernel_tests/attention_wrapper_test.py   | 742 +++---------------
 .../seq2seq/python/ops/attention_wrapper.py  | 111 ++-
 2 files changed, 191 insertions(+), 662 deletions(-)

diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index 40b50338adc..b8b420e10a7 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -19,7 +19,7 @@ from __future__ import division
 from __future__ import print_function
 # pylint: enable=unused-import
 
-import sys
+import collections
 import functools
 
 import numpy as np
@@ -46,15 +46,27 @@ BasicDecoderOutput = basic_decoder.BasicDecoderOutput
 # pylint: disable=invalid-name
 float32 = np.float32
 int32 = np.int32
 array = np.array
+dtype = np.dtype
+
+
+class ResultSummary(
+    collections.namedtuple('ResultSummary', ('shape', 'dtype', 'mean'))):
+  pass
+
+
+def get_result_summary(x):
+  if isinstance(x, np.ndarray):
+    return ResultSummary(x.shape, x.dtype, x.mean())
+  return x
 
 
 class AttentionWrapperTest(test.TestCase):
 
-  def assertAllClose(self, *args, **kwargs):
-    kwargs["atol"] = 1e-4  # For GPU tests
-    kwargs["rtol"] = 1e-4  # For GPU tests
-    return super(AttentionWrapperTest, self).assertAllClose(
-        *args, **kwargs)
+  def assertAllCloseOrEqual(self, x, y, **kwargs):
+    if isinstance(x, np.ndarray) or isinstance(x, float):
+      return super(AttentionWrapperTest, self).assertAllClose(x, y, **kwargs)
+    else:
+      self.assertAllEqual(x, y, **kwargs)
 
   def testAttentionWrapperState(self):
     num_fields = len(wrapper.AttentionWrapperState._fields)  # pylint: disable=protected-access
@@ -71,7 +83,7 @@ class AttentionWrapperTest(test.TestCase):
                          alignment_history=False,
                          expected_final_alignment_history=None,
                          attention_layer_size=6,
-                         name=""):
+                         name=''):
     encoder_sequence_length = [3, 2, 3, 1, 0]
     decoder_sequence_length = [2, 0, 1, 2, 3]
     batch_size = 5
@@ -98,7 +110,7 @@ class AttentionWrapperTest(test.TestCase):
 
     with self.test_session(use_gpu=True) as sess:
       with vs.variable_scope(
-          "root",
+          'root',
           initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
         cell = core_rnn_cell.LSTMCell(cell_depth)
         cell = wrapper.AttentionWrapper(
@@ -147,192 +159,53 @@ class AttentionWrapperTest(test.TestCase):
 
       sess.run(variables.global_variables_initializer())
       sess_results = sess.run({
-          "final_outputs": final_outputs,
-          "final_state": final_state,
-          "state_alignment_history": state_alignment_history,
+          'final_outputs': final_outputs,
+          'final_state': final_state,
+          'state_alignment_history': state_alignment_history,
      })
-      print("Copy/paste (%s)\nexpected_final_output = " % name,
-            sess_results["final_outputs"])
-      sys.stdout.flush()
-      print("Copy/paste (%s)\nexpected_final_state = " % name,
-            sess_results["final_state"])
-      sys.stdout.flush()
-      print("Copy/paste (%s)\nexpected_final_alignment_history = " % name,
-            np.asarray(sess_results["state_alignment_history"]))
-      sys.stdout.flush()
-
nest.map_structure(self.assertAllClose, expected_final_output, - sess_results["final_outputs"]) - nest.map_structure(self.assertAllClose, expected_final_state, - sess_results["final_state"]) + final_output_info = nest.map_structure(get_result_summary, + sess_results['final_outputs']) + final_state_info = nest.map_structure(get_result_summary, + sess_results['final_state']) + print('Copy/paste:\nexpected_final_output = %s' % str(final_output_info)) + print('expected_final_state = %s' % str(final_state_info)) + nest.map_structure(self.assertAllCloseOrEqual, expected_final_output, + final_output_info) + nest.map_structure(self.assertAllCloseOrEqual, expected_final_state, + final_state_info) if alignment_history: # by default, the wrapper emits attention as output - self.assertAllClose( + final_alignment_history_info = nest.map_structure( + get_result_summary, sess_results['state_alignment_history']) + print('expected_final_alignment_history = %s' % + str(final_alignment_history_info)) + nest.map_structure( + self.assertAllCloseOrEqual, # outputs are batch major but the stacked TensorArray is time major - sess_results["state_alignment_history"], - expected_final_alignment_history) + expected_final_alignment_history, + final_alignment_history_info) def testBahdanauNotNormalized(self): create_attention_mechanism = wrapper.BahdanauAttention expected_final_output = BasicDecoderOutput( - rnn_output=array( - [[[ - 2.04633363e-03, 1.89259532e-03, 2.09550979e-03, -3.81628517e-03, - -4.36160620e-03, -6.43933658e-03 - ], [ - 2.41885195e-03, 2.02089013e-03, 2.05879519e-03, -3.85483308e-03, - -3.51473060e-03, -6.14458136e-03 - ], [ - 2.02294230e-03, 2.06955452e-03, 2.34797411e-03, -3.62816593e-03, - -3.80352931e-03, -6.27150526e-03 - ]], [[ - 4.89025004e-03, -1.97221269e-03, 3.34283570e-03, - -2.79326970e-03, 3.63148772e-03, -4.79645561e-03 - ], [ - 5.13446378e-03, -2.03941623e-03, 3.51774949e-03, - -2.83448119e-03, 3.14159272e-03, -5.31486655e-03 - ], [ - 5.20701287e-03, -2.21262546e-03, 3.58187454e-03, - -2.85831164e-03, 3.20822699e-03, -5.20829484e-03 - ]], [[ - -1.34046993e-03, -9.99792013e-04, -2.11631414e-03, - -1.85202830e-03, -5.26227616e-03, -9.08544939e-03 - ], [ - -1.35486713e-03, -1.04408595e-03, -1.96779310e-03, - -1.80004584e-03, -5.61304903e-03, -9.34211537e-03 - ], [ - -1.12452905e-03, -7.68281636e-04, -1.99770415e-03, - -1.88058324e-03, -5.01882844e-03, -9.32228006e-03 - ]], [[ - 1.52967637e-03, -3.97213362e-03, -9.64699371e-04, - 8.51419638e-04, -1.29806029e-03, 6.56482670e-03 - ], [ - 1.22562144e-03, -4.56351135e-03, -1.08190742e-03, - 8.27267300e-04, -2.10060296e-03, 6.43097097e-03 - ], [ - 9.93521884e-04, -4.37386986e-03, -1.41534151e-03, - 6.44790183e-04, -2.16482091e-03, 6.68301852e-03 - ]], [[ - -3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04, - -1.56512906e-04, 9.63474595e-05 - ], [ - -1.04306288e-04, -1.37411975e-04, 2.82689070e-05, - 6.56487318e-05, -1.48634164e-04, -1.84347919e-05 - ], [ - 1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04, - 2.73401442e-04, -2.69805576e-04 - ]]], - dtype=float32), - sample_id=array( - [[2, 0, 2], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]], - dtype=int32)) - + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.00083043973), + sample_id=ResultSummary(shape=(5, 3), dtype=dtype('int32'), mean=2.0)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( - c=array( - [[ - -2.18977481e-02, -8.04181397e-03, -1.48273818e-03, - 1.61075518e-02, -1.37986457e-02, -7.57964421e-03, - 
-8.28644261e-03, -1.18742418e-02, 1.78838037e-02 - ], [ - 1.74201727e-02, -1.41931782e-02, -3.88098788e-03, - 3.19711640e-02, -3.54694054e-02, -2.14694049e-02, - -6.21706853e-03, -1.69323490e-03, -1.94494929e-02 - ], [ - -1.14532551e-02, 8.77828151e-03, -1.62972715e-02, - -1.39963031e-02, 1.34832524e-02, -1.04488730e-02, - 6.16201758e-03, -9.41041857e-03, -6.57599326e-03 - ], [ - -4.74753827e-02, -1.19123599e-02, -7.40140676e-05, - 4.10552323e-02, -1.36711076e-03, 2.11795494e-02, - -2.80460101e-02, -5.44509329e-02, -2.91906092e-02 - ], [ - 2.25644894e-02, -1.40382675e-03, 1.92396250e-02, - 5.49034867e-03, -1.27930511e-02, -3.15603940e-03, - -5.05525898e-03, 2.19191350e-02, 1.62497871e-02 - ]], - dtype=float32), - h=array( - [[ - -1.09847616e-02, -3.97357112e-03, -7.54502777e-04, - 7.91223347e-03, -7.02199014e-03, -3.80705344e-03, - -4.22102772e-03, -6.05491130e-03, 8.92073940e-03 - ], [ - 8.68115202e-03, -7.16950046e-03, -1.88387593e-03, - 1.62680726e-02, -1.76830068e-02, -1.06620435e-02, - -3.07523785e-03, -8.46023730e-04, -9.99386702e-03 - ], [ - -5.71225956e-03, 4.50055022e-03, -8.07653368e-03, - -6.94842264e-03, 6.75687613e-03, -5.12083014e-03, - 3.06244940e-03, -4.61752573e-03, -3.23935854e-03 - ], [ - -2.37231534e-02, -5.88526297e-03, -3.72226204e-05, - 2.01789513e-02, -6.75848918e-04, 1.06686372e-02, - -1.42624676e-02, -2.69628745e-02, -1.45034352e-02 - ], [ - 1.12585640e-02, -6.92534202e-04, 9.88917705e-03, - 2.75237625e-03, -6.56115822e-03, -1.57997780e-03, - -2.54477374e-03, 1.11598391e-02, 7.94144534e-03 - ]], - dtype=float32)), - attention=array( - [[ - 0.00202294, 0.00206955, 0.00234797, -0.00362817, -0.00380353, - -0.00627151 - ], [ - 0.00520701, -0.00221263, 0.00358187, -0.00285831, 0.00320823, - -0.00520829 - ], [ - -0.00112453, -0.00076828, -0.0019977, -0.00188058, -0.00501883, - -0.00932228 - ], [ - 0.00099352, -0.00437387, -0.00141534, 0.00064479, -0.00216482, - 0.00668302 - ], [ - 0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734, - -0.00026981 - ]], - dtype=float32), + c=ResultSummary( + shape=(5, 9), dtype=dtype('float32'), mean=-0.0039763632), + h=ResultSummary( + shape=(5, 9), dtype=dtype('float32'), mean=-0.0019849765)), + attention=ResultSummary( + shape=(5, 6), dtype=dtype('float32'), mean=-0.00081052497), time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=dtype('float32'), mean=0.125), alignment_history=()) - - expected_final_alignment_history = [[[ - 0.12586178, 0.12272788, 0.1271652, 0.12484902, 0.12484902, 0.12484902, - 0.12484902, 0.12484902 - ], [ - 0.12612638, 0.12516938, 0.12478404, 0.12478404, 0.12478404, 0.12478404, - 0.12478404, 0.12478404 - ], [ - 0.12595113, 0.12515794, 0.1255464, 0.1246689, 0.1246689, 0.1246689, - 0.1246689, 0.1246689 - ], [ - 0.12492912, 0.12501013, 0.12501013, 0.12501013, 0.12501013, 0.12501013, - 0.12501013, 0.12501013 - ], [0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125]], [[ - 0.12586173, 0.12272781, 0.12716517, 0.12484905, 0.12484905, 0.12484905, - 0.12484905, 0.12484905 - ], [ - 0.12612617, 0.1251694, 0.12478408, 0.12478408, 0.12478408, 0.12478408, - 0.12478408, 0.12478408 - ], [ - 0.12595108, 0.12515777, 0.1255464, 0.12466895, 0.12466895, 0.12466895, - 0.12466895, 0.12466895 - ], [ - 0.12492914, 0.12501012, 0.12501012, 0.12501012, 0.12501012, 0.12501012, - 0.12501012, 0.12501012 - ], [0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125]], [[ - 0.12586181, 0.12272815, 0.12716556, 0.12484891, 0.12484891, 0.12484891, - 0.12484891, 0.12484891 - ], [ - 0.12612608, 0.12516941, 0.12478409, 0.12478409, 
0.12478409, 0.12478409, - 0.12478409, 0.12478409 - ], [ - 0.12595116, 0.12515792, 0.12554643, 0.1246689, 0.1246689, 0.1246689, - 0.1246689, 0.1246689 - ], [ - 0.1249292, 0.12501012, 0.12501012, 0.12501012, 0.12501012, 0.12501012, - 0.12501012, 0.12501012 - ], [0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125]]] + expected_final_alignment_history = ResultSummary( + shape=(3, 5, 8), dtype=dtype('float32'), mean=0.12500001) self._testWithAttention( create_attention_mechanism, @@ -340,263 +213,54 @@ class AttentionWrapperTest(test.TestCase): expected_final_state, alignment_history=True, expected_final_alignment_history=expected_final_alignment_history, - name="testBahdanauNotNormalized") + name='testBahdanauNotNormalized') def testBahdanauNormalized(self): create_attention_mechanism = functools.partial( wrapper.BahdanauAttention, normalize=True) expected_final_output = BasicDecoderOutput( - rnn_output=array( - [[[ - 1.27064800e-02, 3.57783446e-03, 8.22613202e-03, -1.61504047e-03, - -1.12555185e-02, -3.92740499e-03 - ], [ - 1.30781950e-02, 3.70747922e-03, 8.18992872e-03, -1.65389013e-03, - -1.04098395e-02, -3.63383139e-03 - ], [ - 1.26833543e-02, 3.75790196e-03, 8.48123431e-03, -1.42690970e-03, - -1.07016256e-02, -3.76088684e-03 - ]], [[ - 6.88417302e-03, -2.04071682e-03, 4.17768257e-03, - -4.51408979e-03, 4.90086433e-03, -6.85973791e-03 - ], [ - 7.12782983e-03, -2.10783770e-03, 4.35227761e-03, - -4.55496181e-03, 4.41066315e-03, -7.37757795e-03 - ], [ - 7.20011396e-03, -2.28102156e-03, 4.41620918e-03, - -4.57867794e-03, 4.47713351e-03, -7.27072079e-03 - ]], [[ - -2.20676698e-03, -1.43745833e-03, -1.99429039e-03, - -1.44722988e-03, -7.45461835e-03, -9.80243273e-03 - ], [ - -2.22120387e-03, -1.48139545e-03, -1.84528576e-03, - -1.39490096e-03, -7.80559657e-03, -1.00586927e-02 - ], [ - -1.99079141e-03, -1.20571791e-03, -1.87507609e-03, - -1.47541985e-03, -7.21158786e-03, -1.00391749e-02 - ]], [[ - 1.48755650e-03, -3.89118027e-03, -9.40889120e-04, - 8.36852356e-04, -1.28285377e-03, 6.41521579e-03 - ], [ - 1.18351437e-03, -4.48258361e-03, -1.05809816e-03, - 8.12723883e-04, -2.08540238e-03, 6.28142804e-03 - ], [ - 9.51444614e-04, -4.29300033e-03, -1.39154412e-03, - 6.30271854e-04, -2.14963360e-03, 6.53359853e-03 - ]], [[ - -3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04, - -1.56512906e-04, 9.63474595e-05 - ], [ - -1.04306288e-04, -1.37411975e-04, 2.82689070e-05, - 6.56487318e-05, -1.48634164e-04, -1.84347919e-05 - ], [ - 1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04, - 2.73401442e-04, -2.69805576e-04 - ]]], - dtype=float32), - sample_id=array( - [[0, 0, 0], [0, 0, 0], [1, 3, 1], [5, 5, 5], [3, 3, 2]], - dtype=int32)) - + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.00040482997), + sample_id=ResultSummary( + shape=(5, 3), dtype=dtype('int32'), mean=1.8666666666666667)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( - c=array( - [[ - -2.19953191e-02, -7.81358499e-03, -1.42740645e-03, - 1.62037201e-02, -1.38600282e-02, -7.60386931e-03, - -8.42390209e-03, -1.18884994e-02, 1.78821683e-02 - ], [ - 1.74096227e-02, -1.41773149e-02, -3.89175024e-03, - 3.19635086e-02, -3.54669318e-02, -2.14924756e-02, - -6.20695669e-03, -1.73213519e-03, -1.94583312e-02 - ], [ - -1.14590004e-02, 8.76899902e-03, -1.62825100e-02, - -1.39863417e-02, 1.34333782e-02, -1.04652103e-02, - 6.13503950e-03, -9.39247012e-03, -6.57595927e-03 - ], [ - -4.74739373e-02, -1.19136302e-02, -7.36713409e-05, - 4.10547927e-02, -1.36768632e-03, 
2.11772211e-02, - -2.80480143e-02, -5.44514954e-02, -2.91903671e-02 - ], [ - 2.25644894e-02, -1.40382675e-03, 1.92396250e-02, - 5.49034867e-03, -1.27930511e-02, -3.15603940e-03, - -5.05525898e-03, 2.19191350e-02, 1.62497871e-02 - ]], - dtype=float32), - h=array( - [[ - -1.10325804e-02, -3.86056723e-03, -7.26287195e-04, - 7.95945339e-03, -7.05253659e-03, -3.81913339e-03, - -4.29130904e-03, -6.06246945e-03, 8.91948957e-03 - ], [ - 8.67583323e-03, -7.16136536e-03, -1.88911252e-03, - 1.62639488e-02, -1.76817775e-02, -1.06735229e-02, - -3.07015004e-03, -8.65494134e-04, -9.99815390e-03 - ], [ - -5.71519835e-03, 4.49585915e-03, -8.06909613e-03, - -6.94347266e-03, 6.73189852e-03, -5.12895826e-03, - 3.04909074e-03, -4.60868096e-03, -3.23936995e-03 - ], [ - -2.37224363e-02, -5.88588836e-03, -3.70502457e-05, - 2.01787297e-02, -6.76134136e-04, 1.06674768e-02, - -1.42634623e-02, -2.69631669e-02, -1.45033086e-02 - ], [ - 1.12585640e-02, -6.92534202e-04, 9.88917705e-03, - 2.75237625e-03, -6.56115822e-03, -1.57997780e-03, - -2.54477374e-03, 1.11598391e-02, 7.94144534e-03 - ]], - dtype=float32)), - attention=array( - [[ - 0.01268335, 0.0037579, 0.00848123, -0.00142691, -0.01070163, - -0.00376089 - ], [ - 0.00720011, -0.00228102, 0.00441621, -0.00457868, 0.00447713, - -0.00727072 - ], [ - -0.00199079, -0.00120572, -0.00187508, -0.00147542, -0.00721159, - -0.01003917 - ], [ - 0.00095144, -0.004293, -0.00139154, 0.00063027, -0.00214963, - 0.0065336 - ], [ - 0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734, - -0.00026981 - ]], - dtype=float32), + c=ResultSummary( + shape=(5, 9), dtype=dtype('float32'), mean=-0.0039785588), + h=ResultSummary( + shape=(5, 9), dtype=dtype('float32'), mean=-0.0019861322)), + attention=ResultSummary( + shape=(5, 6), dtype=dtype('float32'), mean=-0.00038488387), time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=dtype('float32'), mean=0.125), alignment_history=()) self._testWithAttention( create_attention_mechanism, expected_final_output, expected_final_state, - name="testBahdanauNormalized") + name='testBahdanauNormalized') def testLuongNotNormalized(self): create_attention_mechanism = wrapper.LuongAttention expected_final_output = BasicDecoderOutput( - rnn_output=array( - [[[ - 1.74922391e-03, 1.85935036e-03, 1.90880906e-03, -3.96941090e-03, - -4.17229906e-03, -6.65769773e-03 - ], [ - 1.99638237e-03, 1.91135216e-03, 1.73234346e-03, -4.00905171e-03, - -3.15058464e-03, -6.34974428e-03 - ], [ - 2.08854163e-03, 2.13832827e-03, 2.49780947e-03, -3.52849509e-03, - -3.96897132e-03, -6.12034509e-03 - ]], [[ - 4.76492243e-03, -1.97180966e-03, 3.29327444e-03, - -2.68205139e-03, 3.55229783e-03, -4.66645230e-03 - ], [ - 5.24956919e-03, -2.00631656e-03, 3.53828911e-03, - -2.96283513e-03, 3.20920302e-03, -5.43697737e-03 - ], [ - 5.30424621e-03, -2.17913301e-03, 3.59509978e-03, - -2.97106663e-03, 3.26450402e-03, -5.31189423e-03 - ]], [[ - -1.36440888e-03, -9.75572329e-04, -2.11284542e-03, - -1.84616144e-03, -5.31351101e-03, -9.12462734e-03 - ], [ - -1.41863467e-03, -1.11081311e-03, -1.94056751e-03, - -1.74311269e-03, -5.76282106e-03, -9.29267984e-03 - ], [ - -1.12129003e-03, -8.15156149e-04, -2.01535341e-03, - -1.89556007e-03, -5.04226238e-03, -9.37188603e-03 - ]], [[ - 1.55163277e-03, -4.01433324e-03, -9.77111282e-04, - 8.59013060e-04, -1.30598655e-03, 6.64281659e-03 - ], [ - 1.26811734e-03, -4.64518648e-03, -1.10593368e-03, - 8.41954607e-04, -2.11594440e-03, 6.58190623e-03 - ], [ - 1.02682540e-03, -4.43787826e-03, -1.43417739e-03, - 6.56281307e-04, -2.17684195e-03, 
6.80128345e-03 - ]], [[ - -3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04, - -1.56512906e-04, 9.63474595e-05 - ], [ - -1.04306288e-04, -1.37411975e-04, 2.82689070e-05, - 6.56487318e-05, -1.48634164e-04, -1.84347919e-05 - ], [ - 1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04, - 2.73401442e-04, -2.69805576e-04 - ]]], - dtype=float32), - sample_id=array( - [[2, 0, 2], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]], - dtype=int32)) - + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.00084602338), + sample_id=ResultSummary(shape=(5, 3), dtype=dtype('int32'), mean=2.0)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( - c=array( - [[ - -2.18942575e-02, -8.05099495e-03, -1.48526859e-03, - 1.61030665e-02, -1.37967104e-02, -7.57982396e-03, - -8.28088820e-03, -1.18743815e-02, 1.78839806e-02 - ], [ - 1.74203254e-02, -1.41929490e-02, -3.88103351e-03, - 3.19709182e-02, -3.54691371e-02, -2.14697979e-02, - -6.21709181e-03, -1.69324467e-03, -1.94495786e-02 - ], [ - -1.14536462e-02, 8.77809525e-03, -1.62965059e-02, - -1.39955431e-02, 1.34810507e-02, -1.04491040e-02, - 6.16097450e-03, -9.40943789e-03, -6.57613343e-03 - ], [ - -4.74765450e-02, -1.19113335e-02, -7.42897391e-05, - 4.10555862e-02, -1.36665069e-03, 2.11814232e-02, - -2.80444007e-02, -5.44504896e-02, -2.91908123e-02 - ], [ - 2.25644894e-02, -1.40382675e-03, 1.92396250e-02, - 5.49034867e-03, -1.27930511e-02, -3.15603940e-03, - -5.05525898e-03, 2.19191350e-02, 1.62497871e-02 - ]], - dtype=float32), - h=array( - [[ - -1.09830676e-02, -3.97811923e-03, -7.55793473e-04, - 7.91002903e-03, -7.02103321e-03, -3.80714820e-03, - -4.21818346e-03, -6.05497835e-03, 8.92084371e-03 - ], [ - 8.68122280e-03, -7.16937613e-03, -1.88389909e-03, - 1.62679367e-02, -1.76828820e-02, -1.06622437e-02, - -3.07524228e-03, -8.46030540e-04, -9.99389403e-03 - ], [ - -5.71245840e-03, 4.50045895e-03, -8.07614625e-03, - -6.94804778e-03, 6.75577158e-03, -5.12094703e-03, - 3.06193763e-03, -4.61703911e-03, -3.23943049e-03 - ], [ - -2.37237271e-02, -5.88475820e-03, -3.73612711e-05, - 2.01791357e-02, -6.75620860e-04, 1.06695695e-02, - -1.42616741e-02, -2.69626491e-02, -1.45035451e-02 - ], [ - 1.12585640e-02, -6.92534202e-04, 9.88917705e-03, - 2.75237625e-03, -6.56115822e-03, -1.57997780e-03, - -2.54477374e-03, 1.11598391e-02, 7.94144534e-03 - ]], - dtype=float32)), - attention=array( - [[ - 0.00208854, 0.00213833, 0.00249781, -0.0035285, -0.00396897, - -0.00612035 - ], [ - 0.00530425, -0.00217913, 0.0035951, -0.00297107, 0.0032645, - -0.00531189 - ], [ - -0.00112129, -0.00081516, -0.00201535, -0.00189556, -0.00504226, - -0.00937189 - ], [ - 0.00102683, -0.00443788, -0.00143418, 0.00065628, -0.00217684, - 0.00680128 - ], [ - 0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734, - -0.00026981 - ]], - dtype=float32), + c=ResultSummary( + shape=(5, 9), dtype=dtype('float32'), mean=-0.0039764317), + h=ResultSummary( + shape=(5, 9), dtype=dtype('float32'), mean=-0.0019850098)), + attention=ResultSummary( + shape=(5, 6), dtype=dtype('float32'), mean=-0.00080144603), time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=dtype('float32'), mean=0.125), alignment_history=()) self._testWithAttention( @@ -604,132 +268,27 @@ class AttentionWrapperTest(test.TestCase): expected_final_output, expected_final_state, attention_mechanism_depth=9, - name="testLuongNotNormalized") + name='testLuongNotNormalized') def testLuongScaled(self): create_attention_mechanism = functools.partial( wrapper.LuongAttention, 
scale=True) expected_final_output = BasicDecoderOutput( - rnn_output=array( - [[[ - 1.74922391e-03, 1.85935036e-03, 1.90880906e-03, -3.96941090e-03, - -4.17229906e-03, -6.65769773e-03 - ], [ - 1.99638237e-03, 1.91135216e-03, 1.73234346e-03, -4.00905171e-03, - -3.15058464e-03, -6.34974428e-03 - ], [ - 2.08854163e-03, 2.13832827e-03, 2.49780947e-03, -3.52849509e-03, - -3.96897132e-03, -6.12034509e-03 - ]], [[ - 4.76492243e-03, -1.97180966e-03, 3.29327444e-03, - -2.68205139e-03, 3.55229783e-03, -4.66645230e-03 - ], [ - 5.24956919e-03, -2.00631656e-03, 3.53828911e-03, - -2.96283513e-03, 3.20920302e-03, -5.43697737e-03 - ], [ - 5.30424621e-03, -2.17913301e-03, 3.59509978e-03, - -2.97106663e-03, 3.26450402e-03, -5.31189423e-03 - ]], [[ - -1.36440888e-03, -9.75572329e-04, -2.11284542e-03, - -1.84616144e-03, -5.31351101e-03, -9.12462734e-03 - ], [ - -1.41863467e-03, -1.11081311e-03, -1.94056751e-03, - -1.74311269e-03, -5.76282106e-03, -9.29267984e-03 - ], [ - -1.12129003e-03, -8.15156149e-04, -2.01535341e-03, - -1.89556007e-03, -5.04226238e-03, -9.37188603e-03 - ]], [[ - 1.55163277e-03, -4.01433324e-03, -9.77111282e-04, - 8.59013060e-04, -1.30598655e-03, 6.64281659e-03 - ], [ - 1.26811734e-03, -4.64518648e-03, -1.10593368e-03, - 8.41954607e-04, -2.11594440e-03, 6.58190623e-03 - ], [ - 1.02682540e-03, -4.43787826e-03, -1.43417739e-03, - 6.56281307e-04, -2.17684195e-03, 6.80128345e-03 - ]], [[ - -3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04, - -1.56512906e-04, 9.63474595e-05 - ], [ - -1.04306288e-04, -1.37411975e-04, 2.82689070e-05, - 6.56487318e-05, -1.48634164e-04, -1.84347919e-05 - ], [ - 1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04, - 2.73401442e-04, -2.69805576e-04 - ]]], - dtype=float32), - sample_id=array( - [[2, 0, 2], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]], - dtype=int32)) - + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.00084602338), + sample_id=ResultSummary(shape=(5, 3), dtype=dtype('int32'), mean=2.0)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( - c=array( - [[ - -2.18942575e-02, -8.05099495e-03, -1.48526859e-03, - 1.61030665e-02, -1.37967104e-02, -7.57982396e-03, - -8.28088820e-03, -1.18743815e-02, 1.78839806e-02 - ], [ - 1.74203254e-02, -1.41929490e-02, -3.88103351e-03, - 3.19709182e-02, -3.54691371e-02, -2.14697979e-02, - -6.21709181e-03, -1.69324467e-03, -1.94495786e-02 - ], [ - -1.14536462e-02, 8.77809525e-03, -1.62965059e-02, - -1.39955431e-02, 1.34810507e-02, -1.04491040e-02, - 6.16097450e-03, -9.40943789e-03, -6.57613343e-03 - ], [ - -4.74765450e-02, -1.19113335e-02, -7.42897391e-05, - 4.10555862e-02, -1.36665069e-03, 2.11814232e-02, - -2.80444007e-02, -5.44504896e-02, -2.91908123e-02 - ], [ - 2.25644894e-02, -1.40382675e-03, 1.92396250e-02, - 5.49034867e-03, -1.27930511e-02, -3.15603940e-03, - -5.05525898e-03, 2.19191350e-02, 1.62497871e-02 - ]], - dtype=float32), - h=array( - [[ - -1.09830676e-02, -3.97811923e-03, -7.55793473e-04, - 7.91002903e-03, -7.02103321e-03, -3.80714820e-03, - -4.21818346e-03, -6.05497835e-03, 8.92084371e-03 - ], [ - 8.68122280e-03, -7.16937613e-03, -1.88389909e-03, - 1.62679367e-02, -1.76828820e-02, -1.06622437e-02, - -3.07524228e-03, -8.46030540e-04, -9.99389403e-03 - ], [ - -5.71245840e-03, 4.50045895e-03, -8.07614625e-03, - -6.94804778e-03, 6.75577158e-03, -5.12094703e-03, - 3.06193763e-03, -4.61703911e-03, -3.23943049e-03 - ], [ - -2.37237271e-02, -5.88475820e-03, -3.73612711e-05, - 2.01791357e-02, -6.75620860e-04, 1.06695695e-02, - 
-1.42616741e-02, -2.69626491e-02, -1.45035451e-02 - ], [ - 1.12585640e-02, -6.92534202e-04, 9.88917705e-03, - 2.75237625e-03, -6.56115822e-03, -1.57997780e-03, - -2.54477374e-03, 1.11598391e-02, 7.94144534e-03 - ]], - dtype=float32)), - attention=array( - [[ - 0.00208854, 0.00213833, 0.00249781, -0.0035285, -0.00396897, - -0.00612035 - ], [ - 0.00530425, -0.00217913, 0.0035951, -0.00297107, 0.0032645, - -0.00531189 - ], [ - -0.00112129, -0.00081516, -0.00201535, -0.00189556, -0.00504226, - -0.00937189 - ], [ - 0.00102683, -0.00443788, -0.00143418, 0.00065628, -0.00217684, - 0.00680128 - ], [ - 0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734, - -0.00026981 - ]], - dtype=float32), + c=ResultSummary( + shape=(5, 9), dtype=dtype('float32'), mean=-0.0039764317), + h=ResultSummary( + shape=(5, 9), dtype=dtype('float32'), mean=-0.0019850098)), + attention=ResultSummary( + shape=(5, 6), dtype=dtype('float32'), mean=-0.00080144603), time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=dtype('float32'), mean=0.125), alignment_history=()) self._testWithAttention( @@ -737,116 +296,27 @@ class AttentionWrapperTest(test.TestCase): expected_final_output, expected_final_state, attention_mechanism_depth=9, - name="testLuongScaled") + name='testLuongScaled') def testNotUseAttentionLayer(self): create_attention_mechanism = wrapper.BahdanauAttention expected_final_output = BasicDecoderOutput( - rnn_output=array( - [[[ - -0.24223405, -0.07791166, 0.15451428, 0.24738294, 0.30900395, - -0.24685201, 0.04992372, 0.18749543, -0.15878429, -0.13678923 - ], [ - -0.2422339, -0.07791159, 0.15451418, 0.24738279, 0.30900383, - -0.24685188, 0.04992369, 0.18749531, -0.15878411, -0.13678911 - ], [ - -0.2422343, -0.07791215, 0.15451413, 0.24738336, 0.30900475, - -0.2468522, 0.04992349, 0.18749571, -0.158785, -0.13678965 - ]], [[ - 0.40035266, 0.12299616, -0.06085059, -0.09197108, 0.11368551, - -0.15302914, 0.00566157, -0.26885766, 0.08546552, 0.18886778 - ], [ - 0.40035242, 0.12299603, -0.06085056, -0.09197091, 0.11368536, - -0.15302882, 0.0056615, -0.26885763, 0.08546554, 0.18886763 - ], [ - 0.40035242, 0.122996, -0.06085056, -0.09197087, 0.11368532, - -0.1530287, 0.00566146, -0.26885769, 0.08546556, 0.18886761 - ]], [[ - -0.4311333, 0.07519469, -0.01551808, 0.1913045, -0.02693807, - -0.21668895, -0.02155721, 0.0013397, 0.21180844, 0.25578707 - ], [ - -0.43113309, 0.07519454, -0.01551818, 0.19130446, -0.0269379, - -0.21668854, -0.021557, 0.00133975, 0.21180828, 0.25578681 - ], [ - -0.43113324, 0.07519463, -0.01551815, 0.1913045, -0.02693798, - -0.21668874, -0.02155712, 0.00133973, 0.21180835, 0.25578696 - ]], [[ - 0.07059932, 0.16451572, 0.01174669, 0.04646531, 0.1427598, - 0.0794456, -0.10852993, 0.15306188, 0.02151393, -0.05590061 - ], [ - 0.07059933, 0.16451576, 0.01174669, 0.04646532, 0.14275983, - 0.07944562, -0.10852996, 0.15306193, 0.02151394, -0.05590062 - ], [ - 0.07059937, 0.16451585, 0.0117467, 0.04646534, 0.1427599, - 0.07944567, -0.10853001, 0.153062, 0.02151395, -0.05590065 - ]], [[0., 0., 0., 0., 0., 0., 0., 0., 0., - 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], - [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]], - dtype=float32), - sample_id=array( - [[4, 4, 4], [0, 0, 0], [9, 9, 9], [1, 1, 1], [0, 0, 0]], - dtype=int32)) - + rnn_output=ResultSummary( + shape=(5, 3, 10), dtype=dtype('float32'), mean=0.019546926), + sample_id=ResultSummary( + shape=(5, 3), dtype=dtype('int32'), mean=2.7999999999999998)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( - c=array( - [[ - 
-0.0181195, -0.01675365, -0.00510353, 0.01559796, - -0.01251448, -0.00437002, -0.01243257, -0.01720199, - 0.02274928 - ], [ - 0.01259979, -0.00839985, -0.00374037, 0.03136262, - -0.03486227, -0.02466441, -0.00496157, -0.00461032, - -0.02098336 - ], [ - -0.00781067, 0.00315682, -0.0138283, -0.01149793, - 0.00485562, -0.01343193, 0.0085915, -0.00632846, -0.01052086 - ], [ - -0.04184828, -0.01223641, 0.0009445, 0.03911434, 0.0043249, - 0.02220661, -0.03006243, -0.05418363, -0.02615385 - ], [ - 0.02282745, -0.00143833, 0.01918138, 0.00545033, - -0.01258384, -0.00303765, -0.00511231, 0.02166323, - 0.01638841 - ]], - dtype=float32), - h=array( - [[ - -0.00910065, -0.00827571, -0.00259689, 0.00764857, - -0.00635579, -0.00218579, -0.00633918, -0.00875511, - 0.01134532 - ], [ - 0.00626597, -0.004241, -0.00181303, 0.01597157, -0.0173375, - -0.01224921, -0.00244522, -0.00231299, -0.0107822 - ], [ - -0.00391383, 0.00162017, -0.00682621, -0.00570264, - 0.00244099, -0.00659772, 0.00426475, -0.00309861, - -0.00520028 - ], [ - -0.02087484, -0.00603306, 0.00047561, 0.01920062, - 0.00213875, 0.01115329, -0.0152659, -0.02687523, -0.01297523 - ], [ - 0.01138975, -0.00070959, 0.00986007, 0.0027323, -0.00645386, - -0.00152054, -0.00257339, 0.01103063, 0.00800891 - ]], - dtype=float32)), - attention=array( - [[ - -0.2422343, -0.07791215, 0.15451413, 0.24738336, 0.30900475, - -0.2468522, 0.04992349, 0.18749571, -0.158785, -0.13678965 - ], [ - 0.40035242, 0.122996, -0.06085056, -0.09197087, 0.11368532, - -0.1530287, 0.00566146, -0.26885769, 0.08546556, 0.18886761 - ], [ - -0.43113324, 0.07519463, -0.01551815, 0.1913045, -0.02693798, - -0.21668874, -0.02155712, 0.00133973, 0.21180835, 0.25578696 - ], [ - 0.07059937, 0.16451585, 0.0117467, 0.04646534, 0.1427599, - 0.07944567, -0.10853001, 0.153062, 0.02151395, -0.05590065 - ], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], - dtype=float32), + c=ResultSummary( + shape=(5, 9), dtype=dtype('float32'), mean=-0.0041728448), + h=ResultSummary( + shape=(5, 9), dtype=dtype('float32'), mean=-0.002085865)), + attention=ResultSummary( + shape=(5, 10), dtype=dtype('float32'), mean=0.019546915), time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=dtype('float32'), mean=0.125), alignment_history=()) self._testWithAttention( @@ -854,8 +324,8 @@ class AttentionWrapperTest(test.TestCase): expected_final_output, expected_final_state, attention_layer_size=None, - name="testNotUseAttentionLayer") + name='testNotUseAttentionLayer') -if __name__ == "__main__": +if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 8d1c0c59e06..686a85e4e73 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -121,8 +121,13 @@ class _BaseAttentionMechanism(AttentionMechanism): 2. Preprocessing and storing the memory. """ - def __init__(self, query_layer, memory, memory_sequence_length=None, - memory_layer=None, check_inner_dims_defined=True, + def __init__(self, + query_layer, + memory, + probability_fn, + memory_sequence_length=None, + memory_layer=None, + check_inner_dims_defined=True, name=None): """Construct base AttentionMechanism class. @@ -132,6 +137,9 @@ class _BaseAttentionMechanism(AttentionMechanism): provided, the shape of `query` must match that of `memory_layer`. memory: The memory to query; usually the output of an RNN encoder. 
         This tensor should be shaped `[batch_size, max_time, ...]`.
+      probability_fn: A `callable`. Converts the score and previous alignments
+        to probabilities. Its signature should be:
+        `probabilities = probability_fn(score, previous_alignments)`.
       memory_sequence_length (optional): Sequence lengths for the batch entries
         in memory. If provided, the memory tensor rows are masked with zeros
         for values past the respective sequence lengths.
@@ -154,6 +162,10 @@ class _BaseAttentionMechanism(AttentionMechanism):
           "memory_layer is not a Layer: %s" % type(memory_layer).__name__)
     self._query_layer = query_layer
     self._memory_layer = memory_layer
+    if not callable(probability_fn):
+      raise TypeError("probability_fn must be callable, saw type: %s" %
+                      type(probability_fn).__name__)
+    self._probability_fn = probability_fn
     with ops.name_scope(
         name, "BaseAttentionMechanismInit", nest.flatten(memory)):
       self._values = _prepare_memory(
@@ -164,6 +176,8 @@ class _BaseAttentionMechanism(AttentionMechanism):
           else self._values)
     self._batch_size = (
         self._keys.shape[0].value or array_ops.shape(self._keys)[0])
+    self._alignments_size = (self._keys.shape[1].value or
+                             array_ops.shape(self._keys)[1])
 
   @property
   def memory_layer(self):
@@ -185,6 +199,29 @@ class _BaseAttentionMechanism(AttentionMechanism):
   def batch_size(self):
     return self._batch_size
 
+  @property
+  def alignments_size(self):
+    return self._alignments_size
+
+  def initial_alignments(self, batch_size, dtype):
+    """Creates the initial alignment values for the `AttentionWrapper` class.
+
+    This is important for AttentionMechanisms that use the previous alignment
+    to calculate the alignment at the next time step (e.g. monotonic attention).
+
+    The default behavior is to return a tensor of all zeros.
+
+    Args:
+      batch_size: `int32` scalar, the batch_size.
+      dtype: The `dtype`.
+
+    Returns:
+      A `dtype` tensor shaped `[batch_size, alignments_size]`
+      (`alignments_size` is the values' `max_time`).
+    """
+    max_time = self._alignments_size
+    return _zero_state_tensors(max_time, batch_size, dtype)
+
 
 class LuongAttention(_BaseAttentionMechanism):
   """Implements Luong-style (multiplicative) attention scoring.
@@ -208,6 +245,7 @@ class LuongAttention(_BaseAttentionMechanism):
                memory,
                memory_sequence_length=None,
                scale=False,
+               probability_fn=None,
                name="LuongAttention"):
     """Construct the AttentionMechanism mechanism.
 
@@ -219,31 +257,43 @@ class LuongAttention(_BaseAttentionMechanism):
         in memory. If provided, the memory tensor rows are masked with zeros
         for values past the respective sequence lengths.
       scale: Python boolean. Whether to scale the energy term.
+      probability_fn: (optional) A `callable`. Converts the score to
+        probabilities. The default is @{tf.nn.softmax}. Other options include
+        @{tf.contrib.seq2seq.hardmax} and @{tf.contrib.sparsemax.sparsemax}.
+        Its signature should be: `probabilities = probability_fn(score)`.
       name: Name to use when creating ops.
     """
     # For LuongAttention, we only transform the memory layer; thus
    # num_units **must** match expected the query depth.
+    if probability_fn is None:
+      probability_fn = nn_ops.softmax
+    wrapped_probability_fn = lambda score, _: probability_fn(score)
     super(LuongAttention, self).__init__(
         query_layer=None,
         memory_layer=layers_core.Dense(
             num_units, name="memory_layer", use_bias=False),
         memory=memory,
+        probability_fn=wrapped_probability_fn,
         memory_sequence_length=memory_sequence_length,
         name=name)
     self._num_units = num_units
     self._scale = scale
     self._name = name
 
-  def __call__(self, query):
+  def __call__(self, query, previous_alignments):
     """Score the query based on the keys and values.
 
     Args:
       query: Tensor of dtype matching `self.values` and shape
         `[batch_size, query_depth]`.
+      previous_alignments: Tensor of dtype matching `self.values` and shape
+        `[batch_size, alignments_size]`
+        (`alignments_size` is memory's `max_time`).
 
     Returns:
-      score: Tensor of dtype matching `self.values` and shape
-        `[batch_size, max_time]` (`max_time` is memory's `max_time`).
+      alignments: Tensor of dtype matching `self.values` and shape
+        `[batch_size, alignments_size]` (`alignments_size` is memory's
+        `max_time`).
 
     Raises:
       ValueError: If `key` and `query` depths do not match.
@@ -281,7 +331,8 @@ class LuongAttention(_BaseAttentionMechanism):
           "attention_g", dtype=dtype, initializer=1.)
       score = g * score
-    return score
+    alignments = self._probability_fn(score, previous_alignments)
+    return alignments
 
 
 class BahdanauAttention(_BaseAttentionMechanism):
@@ -311,6 +362,7 @@ class BahdanauAttention(_BaseAttentionMechanism):
                memory,
                memory_sequence_length=None,
                normalize=False,
+               probability_fn=None,
                name="BahdanauAttention"):
     """Construct the Attention mechanism.
 
@@ -322,30 +374,42 @@ class BahdanauAttention(_BaseAttentionMechanism):
         in memory. If provided, the memory tensor rows are masked with zeros
         for values past the respective sequence lengths.
       normalize: Python boolean. Whether to normalize the energy term.
+      probability_fn: (optional) A `callable`. Converts the score to
+        probabilities. The default is @{tf.nn.softmax}. Other options include
+        @{tf.contrib.seq2seq.hardmax} and @{tf.contrib.sparsemax.sparsemax}.
+        Its signature should be: `probabilities = probability_fn(score)`.
       name: Name to use when creating ops.
     """
+    if probability_fn is None:
+      probability_fn = nn_ops.softmax
+    wrapped_probability_fn = lambda score, _: probability_fn(score)
     super(BahdanauAttention, self).__init__(
         query_layer=layers_core.Dense(
             num_units, name="query_layer", use_bias=False),
         memory_layer=layers_core.Dense(
             num_units, name="memory_layer", use_bias=False),
         memory=memory,
+        probability_fn=wrapped_probability_fn,
         memory_sequence_length=memory_sequence_length,
         name=name)
     self._num_units = num_units
     self._normalize = normalize
     self._name = name
 
-  def __call__(self, query):
+  def __call__(self, query, previous_alignments):
     """Score the query based on the keys and values.
 
     Args:
       query: Tensor of dtype matching `self.values` and shape
         `[batch_size, query_depth]`.
+      previous_alignments: Tensor of dtype matching `self.values` and shape
+        `[batch_size, alignments_size]`
+        (`alignments_size` is memory's `max_time`).
 
     Returns:
-      score: Tensor of dtype matching `self.values` and shape
-        `[batch_size, max_time]` (`max_time` is memory's `max_time`).
+      alignments: Tensor of dtype matching `self.values` and shape
+        `[batch_size, alignments_size]` (`alignments_size` is memory's
+        `max_time`).
     """
     with variable_scope.variable_scope(None, "bahdanau_attention", [query]):
       processed_query = self.query_layer(query) if self.query_layer else query
@@ -373,20 +437,23 @@ class BahdanauAttention(_BaseAttentionMechanism):
       score = math_ops.reduce_sum(v * math_ops.tanh(keys + processed_query), [2])
-    return score
+    alignments = self._probability_fn(score, previous_alignments)
+    return alignments
 
 
 class AttentionWrapperState(
     collections.namedtuple("AttentionWrapperState",
-                           ("cell_state", "attention", "time",
+                           ("cell_state", "attention", "time", "alignments",
                             "alignment_history"))):
   """`namedtuple` storing the state of a `AttentionWrapper`.
 
   Contains:
 
-    - `cell_state`: The state of the wrapped `RNNCell`.
+    - `cell_state`: The state of the wrapped `RNNCell` at the previous time
+      step.
     - `attention`: The attention emitted at the previous time step.
     - `time`: int32 scalar containing the current time step.
+    - `alignments`: The alignment emitted at the previous time step.
     - `alignment_history`: (if enabled) a `TensorArray` containing alignment
       matrices from all time steps. Call `stack()` to convert to a `Tensor`.
   """
@@ -443,7 +510,6 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
                attention_layer_size=None,
                alignment_history=False,
                cell_input_fn=None,
-               probability_fn=None,
                output_attention=True,
                initial_cell_state=None,
                name=None):
@@ -461,9 +527,6 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
         time major `TensorArray` on which you must call `stack()`).
       cell_input_fn: (optional) A `callable`. The default is:
         `lambda inputs, attention: array_ops.concat([inputs, attention], -1)`.
-      probability_fn: (optional) A `callable`. Converts the score to
-        probabilities. The default is @{tf.nn.softmax}. Other options include
-        @{tf.contrib.seq2seq.hardmax} and @{tf.contrib.sparsemax.sparsemax}.
       output_attention: Python bool. If `True` (default), the output at each
         time step is the attention value. This is the behavior of Luong-style
         attention mechanisms. If `False`, the output at each time step is
@@ -495,13 +558,6 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
         raise TypeError(
             "cell_input_fn must be callable, saw type: %s"
             % type(cell_input_fn).__name__)
-    if probability_fn is None:
-      probability_fn = nn_ops.softmax
-    else:
-      if not callable(probability_fn):
-        raise TypeError(
-            "probability_fn must be callable, saw type: %s"
-            % type(probability_fn).__name__)
 
     if attention_layer_size is not None:
       self._attention_layer = layers_core.Dense(
@@ -514,7 +570,6 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
     self._cell = cell
     self._attention_mechanism = attention_mechanism
     self._cell_input_fn = cell_input_fn
-    self._probability_fn = probability_fn
     self._output_attention = output_attention
     self._alignment_history = alignment_history
     with ops.name_scope(name, "AttentionWrapperInit"):
@@ -553,6 +608,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
         cell_state=self._cell.state_size,
         time=tensor_shape.TensorShape([]),
         attention=self._attention_size,
+        alignments=self._attention_mechanism.alignments_size,
         alignment_history=())  # alignment_history is sometimes a TensorArray
 
   def zero_state(self, batch_size, dtype):
@@ -586,6 +642,8 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
           time=array_ops.zeros([], dtype=dtypes.int32),
           attention=_zero_state_tensors(self._attention_size, batch_size,
                                         dtype),
+          alignments=self._attention_mechanism.initial_alignments(
+              batch_size, dtype),
           alignment_history=alignment_history)
 
   def call(self, inputs, state):
@@ -637,8 +695,8 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
       cell_output = array_ops.identity(
           cell_output, name="checked_cell_output")
 
-    score = self._attention_mechanism(cell_output)
-    alignments = self._probability_fn(score)
+    alignments = self._attention_mechanism(
+        cell_output, previous_alignments=state.alignments)
 
     # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
     expanded_alignments = array_ops.expand_dims(alignments, 1)
@@ -671,6 +729,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
         time=state.time + 1,
         cell_state=next_cell_state,
         attention=attention,
+        alignments=alignments,
         alignment_history=alignment_history)
 
     if self._output_attention:
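
A minimal usage sketch of the API after this change (the shapes, placeholder names, and the choice of tf.nn.softmax below are illustrative assumptions, not taken from the patch): probability_fn is now passed to the attention mechanism rather than to AttentionWrapper, and the wrapper threads the previous alignments through AttentionWrapperState so a mechanism such as monotonic attention can condition on them.

    import tensorflow as tf

    batch_size, max_time, depth, num_units = 5, 8, 6, 9

    # Encoder outputs act as the attention memory.
    memory = tf.placeholder(tf.float32, [batch_size, max_time, depth])
    memory_sequence_length = tf.placeholder(tf.int32, [batch_size])

    # The probability_fn now lives on the mechanism; tf.nn.softmax is the
    # default, and tf.contrib.seq2seq.hardmax is one documented alternative.
    attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
        num_units=num_units,
        memory=memory,
        memory_sequence_length=memory_sequence_length,
        probability_fn=tf.nn.softmax)

    # AttentionWrapper no longer accepts a probability_fn argument.
    cell = tf.contrib.seq2seq.AttentionWrapper(
        tf.contrib.rnn.LSTMCell(num_units),
        attention_mechanism,
        attention_layer_size=num_units,
        alignment_history=True)

    # zero_state now also seeds state.alignments via
    # attention_mechanism.initial_alignments(batch_size, dtype).
    state = cell.zero_state(batch_size, tf.float32)

    # At each step the wrapper calls
    #   alignments = attention_mechanism(cell_output,
    #                                    previous_alignments=state.alignments)
    # and stores the new alignments back into the next AttentionWrapperState.
    inputs = tf.placeholder(tf.float32, [batch_size, depth])
    output, state = cell(inputs, state)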