Corrected dimension notes in attention_wrapper.py (#12107)
parent 17e82cc274
commit a9e2817c35
@@ -979,9 +979,9 @@ def _compute_attention(attention_mechanism, cell_output, previous_alignments,
   # alignments shape is
   #   [batch_size, 1, memory_time]
   # attention_mechanism.values shape is
-  #   [batch_size, memory_time, attention_mechanism.num_units]
+  #   [batch_size, memory_time, memory_size]
   # the batched matmul is over memory_time, so the output shape is
-  #   [batch_size, 1, attention_mechanism.num_units].
+  #   [batch_size, 1, memory_size].
   # we then squeeze out the singleton dim.
   context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
   context = array_ops.squeeze(context, [1])
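For context, a minimal, self-contained sketch of the shape arithmetic these corrected comments describe. It uses the public TensorFlow API rather than the internal math_ops/array_ops modules, and the sizes (batch_size=2, memory_time=5, memory_size=3) are made up for illustration; it is not the module's actual code.

# Hypothetical sketch of the context computation, not the library implementation.
import tensorflow as tf

batch_size, memory_time, memory_size = 2, 5, 3

# alignments: [batch_size, memory_time], one weight per memory slot
alignments = tf.nn.softmax(tf.random.normal([batch_size, memory_time]))
# values stands in for attention_mechanism.values:
#   [batch_size, memory_time, memory_size]
values = tf.random.normal([batch_size, memory_time, memory_size])

# expand to [batch_size, 1, memory_time] so the batched matmul contracts memory_time
expanded_alignments = tf.expand_dims(alignments, 1)
# [batch_size, 1, memory_time] @ [batch_size, memory_time, memory_size]
#   -> [batch_size, 1, memory_size]
context = tf.matmul(expanded_alignments, values)
# squeeze out the singleton dim -> [batch_size, memory_size]
context = tf.squeeze(context, [1])

print(context.shape)  # (2, 3)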