Fix the mean the description of the output shape, which should be [batch_size, 1, max_time] (#16642)
This commit is contained in:
parent
fe7bbd8565
commit
283f03c825
@ -331,7 +331,7 @@ def _luong_score(query, keys, scale):
|
||||
# batched matmul on:
|
||||
# [batch_size, 1, depth] . [batch_size, depth, max_time]
|
||||
# resulting in an output shape of:
|
||||
# [batch_time, 1, max_time].
|
||||
# [batch_size, 1, max_time].
|
||||
# we then squeeze out the center singleton dimension.
|
||||
score = math_ops.matmul(query, keys, transpose_b=True)
|
||||
score = array_ops.squeeze(score, [1])
|
||||
|
Loading…
Reference in New Issue
Block a user