diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index c434113520f..259c8e08ad9 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -979,9 +979,9 @@ def _compute_attention(attention_mechanism, cell_output, previous_alignments,
   # alignments shape is
   #   [batch_size, 1, memory_time]
   # attention_mechanism.values shape is
-  #   [batch_size, memory_time, attention_mechanism.num_units]
+  #   [batch_size, memory_time, memory_size]
   # the batched matmul is over memory_time, so the output shape is
-  #   [batch_size, 1, attention_mechanism.num_units].
+  #   [batch_size, 1, memory_size].
   # we then squeeze out the singleton dim.
   context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
   context = array_ops.squeeze(context, [1])
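
A minimal, standalone sketch of the shape arithmetic the corrected comment describes: the batched matmul contracts over `memory_time`, so the context depends on the depth of `attention_mechanism.values` (the memory size), not on `attention_mechanism.num_units`. This is not part of the patch; it uses the public `tf.matmul`/`tf.squeeze` instead of the internal `math_ops`/`array_ops` modules, and the `batch_size`/`memory_time`/`memory_size` values are arbitrary placeholders.

```python
# Shape-only illustration; dimension values are arbitrary placeholders.
import tensorflow as tf

batch_size, memory_time, memory_size = 4, 7, 16

# alignments: [batch_size, memory_time]
alignments = tf.ones([batch_size, memory_time])
# expanded_alignments: [batch_size, 1, memory_time]
expanded_alignments = tf.expand_dims(alignments, 1)

# Stand-in for attention_mechanism.values: [batch_size, memory_time, memory_size]
values = tf.ones([batch_size, memory_time, memory_size])

# Batched matmul contracts over memory_time:
#   [batch_size, 1, memory_time] x [batch_size, memory_time, memory_size]
#     -> [batch_size, 1, memory_size]
context = tf.matmul(expanded_alignments, values)

# Squeeze out the singleton dim -> [batch_size, memory_size]
context = tf.squeeze(context, [1])
print(context.shape)  # (4, 16), i.e. [batch_size, memory_size]
```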