Fix the mean the description of the output shape, which should be [batch_size, 1, max_time] (#16642)
This commit is contained in:
parent
fe7bbd8565
commit
283f03c825
@ -331,7 +331,7 @@ def _luong_score(query, keys, scale):
|
|||||||
# batched matmul on:
|
# batched matmul on:
|
||||||
# [batch_size, 1, depth] . [batch_size, depth, max_time]
|
# [batch_size, 1, depth] . [batch_size, depth, max_time]
|
||||||
# resulting in an output shape of:
|
# resulting in an output shape of:
|
||||||
# [batch_time, 1, max_time].
|
# [batch_size, 1, max_time].
|
||||||
# we then squeeze out the center singleton dimension.
|
# we then squeeze out the center singleton dimension.
|
||||||
score = math_ops.matmul(query, keys, transpose_b=True)
|
score = math_ops.matmul(query, keys, transpose_b=True)
|
||||||
score = array_ops.squeeze(score, [1])
|
score = array_ops.squeeze(score, [1])
|
||||||
|
Loading…
Reference in New Issue
Block a user